Commit b8a212ef authored by Rick Ho

fix bugs to pass test

parent bf4388c0
@@ -66,9 +66,8 @@ void moe_cuda_forward_impl(
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
-        const size_t num_expert) {
-    auto h = getCudaStreamManager(num_expert);
+        const size_t num_expert,
+        cublasOperation_t transb) {
 #ifdef MOE_BREAKDOWN
     timestamp(t_init);
@@ -106,6 +105,7 @@ void moe_cuda_forward_impl(
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
+    }
     int *pos = new int[batch_size];
     int *d_pos;
@@ -124,12 +124,12 @@ void moe_cuda_forward_impl(
 #endif
     batch_scatter_kernel<scalar_t>
-            <<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
+            <<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
             input_buf);
-    h->sync(0);
+    // smgr.sync(0);
 #ifdef MOE_BREAKDOWN
-    h->sync();
+    // h->sync();
     timestamp(t_scatter);
     fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
             1e6);
@@ -147,7 +147,7 @@ void moe_cuda_forward_impl(
                 in_feat);
 #endif
         // Use T(B) x T(A) = T(C) to produce row-major C
-        checkCudaErrors(cublasXgemm(h->getHandle(i),
+        checkCudaErrors(cublasXgemm(smgr.handle, // h->getHandle(i),
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 out_feat, expert_count[i], in_feat,
@@ -167,11 +167,11 @@ void moe_cuda_forward_impl(
             1e6);
 #endif
-    h->sync();
+    // h->sync();
     batch_gather_kernel<scalar_t>
-            <<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, output_buf,
+            <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
             output);
-    h->sync(0);
+    // h->sync(0);
 #ifdef MOE_BREAKDOWN
     timestamp(t_gather);
@@ -203,8 +203,8 @@ void moe_cuda_grad_weight(
     scalar_t alpha = 1, beta = 1;
     checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
     for (size_t i=0; i<batch_size; ++i) {
-        checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + gate_host[i])));
-        checkCudaErrors(cublasXgemm(h->handles[0],
+        // checkCudaErrors(cublasSetStream);
+        checkCudaErrors(cublasXgemm(smgr.handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 out_feat,
@@ -253,7 +253,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
             batch_size,
             in_feat,
             out_feat,
-            num_expert
+            num_expert,
+            CUBLAS_OP_T
         );
     }));
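
The "Use T(B) x T(A) = T(C) to produce row-major C" comment in the hunk above is the standard trick for driving column-major cuBLAS with row-major (C/PyTorch) buffers: swap the two operands so the library computes T(C) = T(B) x T(A), whose column-major layout is byte-identical to row-major C. A minimal standalone sketch of that trick, with illustrative names only (this is not code from the commit):

    // row_major_gemm.cu: row-major C = A * B via column-major cuBLAS.
    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    int main() {
        const int m = 2, k = 3, n = 4;
        // Row-major host matrices: A is m x k, B is k x n, C is m x n.
        std::vector<float> A = {1, 2, 3, 4, 5, 6};
        std::vector<float> B(k * n, 1.0f);
        std::vector<float> C(m * n, 0.0f);

        float *dA, *dB, *dC;
        cudaMalloc(&dA, A.size() * sizeof(float));
        cudaMalloc(&dB, B.size() * sizeof(float));
        cudaMalloc(&dC, C.size() * sizeof(float));
        cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
        cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);

        cublasHandle_t handle;
        cublasCreate(&handle);

        // cuBLAS sees each row-major buffer as its transpose, so swapping the
        // operands computes T(C) = T(B) x T(A) in column-major form, which is
        // exactly row-major C in memory.
        const float alpha = 1.0f, beta = 0.0f;
        cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                    n, m, k,
                    &alpha,
                    dB, n,   // T(B): n x k, leading dimension n
                    dA, k,   // T(A): k x m, leading dimension k
                    &beta,
                    dC, n);  // T(C): n x m, leading dimension n

        cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) printf("%6.1f ", C[i * n + j]);
            printf("\n");
        }

        cublasDestroy(handle);
        cudaFree(dA); cudaFree(dB); cudaFree(dC);
        return 0;
    }

Expected output is row 0 = 6 6 6 6 and row 1 = 15 15 15 15, i.e. each row of A summed against the all-ones B. The GEMM in the diff additionally applies CUBLAS_OP_T to the weight operand, presumably because the expert weights are stored [out_feat, in_feat] in the usual torch.nn.Linear layout.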