Commit eb1525ec authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

update

parent 30283570
...@@ -78,10 +78,10 @@ void moe_cuda_forward_impl( ...@@ -78,10 +78,10 @@ void moe_cuda_forward_impl(
const size_t* gate, const size_t* gate,
const scalar_t* weight, const scalar_t* weight,
scalar_t* output, scalar_t* output,
size_t batch_size, const size_t batch_size,
size_t top_k, const size_t top_k,
size_t in_feat, const size_t in_feat,
size_t out_feat) { const size_t out_feat) {
cublasHandle_t handle; cublasHandle_t handle;
...@@ -135,12 +135,12 @@ void moe_cuda_forward_impl( ...@@ -135,12 +135,12 @@ void moe_cuda_forward_impl(
int main() { int main() {
const data_t *input, *weight; data_t *input, *weight;
data_t *output; data_t *output;
const size_t *gate; size_t *gate;
checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(const data_t))); checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(const data_t))); checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t))); checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t))); checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment