Commit eb1525ec authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

updatre

parent 30283570
......@@ -78,10 +78,10 @@ void moe_cuda_forward_impl(
const size_t* gate,
const scalar_t* weight,
scalar_t* output,
size_t batch_size,
size_t top_k,
size_t in_feat,
size_t out_feat) {
const size_t batch_size,
const size_t top_k,
const size_t in_feat,
const size_t out_feat) {
cublasHandle_t handle;
......@@ -135,12 +135,12 @@ void moe_cuda_forward_impl(
int main() {
const data_t *input, *weight;
data_t *input, *weight;
data_t *output;
const size_t *gate;
size_t *gate;
checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(const data_t)));
checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(const data_t)));
checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment