Commit b8a212ef authored by Rick Ho
Browse files

fix bugs to pass test

parent bf4388c0
...@@ -66,9 +66,8 @@ void moe_cuda_forward_impl( ...@@ -66,9 +66,8 @@ void moe_cuda_forward_impl(
const size_t batch_size, const size_t batch_size,
const size_t in_feat, const size_t in_feat,
const size_t out_feat, const size_t out_feat,
const size_t num_expert) { const size_t num_expert,
cublasOperation_t transb) {
auto h = getCudaStreamManager(num_expert);
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_init); timestamp(t_init);
...@@ -106,6 +105,7 @@ void moe_cuda_forward_impl( ...@@ -106,6 +105,7 @@ void moe_cuda_forward_impl(
expert_ptr[0] = 0; expert_ptr[0] = 0;
for (int i = 1; i < num_expert; ++i) { for (int i = 1; i < num_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1]; expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
}
int *pos = new int[batch_size]; int *pos = new int[batch_size];
int *d_pos; int *d_pos;
...@@ -124,12 +124,12 @@ void moe_cuda_forward_impl( ...@@ -124,12 +124,12 @@ void moe_cuda_forward_impl(
#endif #endif
batch_scatter_kernel<scalar_t> batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input, <<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
input_buf); input_buf);
h->sync(0); // smgr.sync(0);
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
h->sync(); // h->sync();
timestamp(t_scatter); timestamp(t_scatter);
fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) * fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
1e6); 1e6);
...@@ -147,7 +147,7 @@ void moe_cuda_forward_impl( ...@@ -147,7 +147,7 @@ void moe_cuda_forward_impl(
in_feat); in_feat);
#endif #endif
// Use T(B) x T(A) = T(C) to produce row-major C // Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(h->getHandle(i), checkCudaErrors(cublasXgemm(smgr.handle, // h->getHandle(i),
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
out_feat, expert_count[i], in_feat, out_feat, expert_count[i], in_feat,
...@@ -167,11 +167,11 @@ void moe_cuda_forward_impl( ...@@ -167,11 +167,11 @@ void moe_cuda_forward_impl(
1e6); 1e6);
#endif #endif
h->sync(); // h->sync();
batch_gather_kernel<scalar_t> batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, output_buf, <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
output); output);
h->sync(0); // h->sync(0);
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_gather); timestamp(t_gather);
...@@ -203,8 +203,8 @@ void moe_cuda_grad_weight( ...@@ -203,8 +203,8 @@ void moe_cuda_grad_weight(
scalar_t alpha = 1, beta = 1; scalar_t alpha = 1, beta = 1;
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost)); checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) { for (size_t i=0; i<batch_size; ++i) {
checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + gate_host[i]))); // checkCudaErrors(cublasSetStream);
checkCudaErrors(cublasXgemm(h->handles[0], checkCudaErrors(cublasXgemm(smgr.handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
out_feat, out_feat,
...@@ -253,7 +253,8 @@ std::vector<torch::Tensor> moe_cuda_forward( ...@@ -253,7 +253,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
batch_size, batch_size,
in_feat, in_feat,
out_feat, out_feat,
num_expert num_expert,
CUBLAS_OP_T
); );
})); }));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment