Commit 7e949a62 authored by Rick Ho
Browse files

fix output bug

parent 4cb75d42
...@@ -82,8 +82,6 @@ void moe_cuda_forward_impl( ...@@ -82,8 +82,6 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size * checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
in_feat)); in_feat));
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_malloc); timestamp(t_malloc);
...@@ -136,12 +134,8 @@ void moe_cuda_forward_impl( ...@@ -136,12 +134,8 @@ void moe_cuda_forward_impl(
expert_sz += expert_n[i]; expert_sz += expert_n[i];
} }
scalar_t *input_buf, *hidden_buf, *output_buf; scalar_t *input_buf, *hidden_buf, *output_buf;
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&hidden_buf, checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat)); sizeof(scalar_t) * expert_sz * hidden_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
#ifdef MOE_DEBUG #ifdef MOE_DEBUG
for (int i = 0; i < tot_expert; ++i) { for (int i = 0; i < tot_expert; ++i) {
...@@ -166,6 +160,11 @@ void moe_cuda_forward_impl( ...@@ -166,6 +160,11 @@ void moe_cuda_forward_impl(
local_input_buf); local_input_buf);
h->sync(0); h->sync(0);
if (cm->rank > 1) {
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
ncclGroupStart(); ncclGroupStart();
int recv_ptr = 0; int recv_ptr = 0;
for (int i = 0; i < num_expert; ++i) { for (int i = 0; i < num_expert; ++i) {
...@@ -192,6 +191,9 @@ void moe_cuda_forward_impl( ...@@ -192,6 +191,9 @@ void moe_cuda_forward_impl(
} }
} }
ncclGroupEnd(); ncclGroupEnd();
} else {
input_buf = local_input_buf;
}
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
h->sync(); h->sync();
...@@ -244,13 +246,16 @@ void moe_cuda_forward_impl( ...@@ -244,13 +246,16 @@ void moe_cuda_forward_impl(
1e6); 1e6);
#endif #endif
if (cm->rank > 1) {
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
ncclGroupStart(); ncclGroupStart();
int send_ptr = 0; int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) { for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) { for (int j = 0; j < cm->size; ++j) {
int recv_id = i + j * num_expert; int recv_id = i + j * num_expert;
if (expert_count[recv_id]) { if (expert_count[recv_id]) {
ncclRecv(local_input_buf + expert_ptr[recv_id] * in_feat, ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat,
expert_count[recv_id] * in_feat * sizeof(scalar_t), expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar, ncclChar,
j, j,
...@@ -259,7 +264,7 @@ void moe_cuda_forward_impl( ...@@ -259,7 +264,7 @@ void moe_cuda_forward_impl(
} }
int send_id = i * cm->size + j; int send_id = i * cm->size + j;
if (all_expert_count[send_id]) { if (all_expert_count[send_id]) {
ncclSend(input_buf + send_ptr * in_feat, ncclSend(output_buf + send_ptr * in_feat,
all_expert_count[send_id] * in_feat * sizeof(scalar_t), all_expert_count[send_id] * in_feat * sizeof(scalar_t),
ncclChar, ncclChar,
j, j,
...@@ -270,6 +275,9 @@ void moe_cuda_forward_impl( ...@@ -270,6 +275,9 @@ void moe_cuda_forward_impl(
} }
} }
ncclGroupEnd(); ncclGroupEnd();
} else {
local_output_buf = output_buf;
}
batch_gather_kernel<scalar_t> batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, <<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
...@@ -287,8 +295,10 @@ void moe_cuda_forward_impl( ...@@ -287,8 +295,10 @@ void moe_cuda_forward_impl(
cudaFree(input_buf); cudaFree(input_buf);
cudaFree(hidden_buf); cudaFree(hidden_buf);
cudaFree(output_buf); cudaFree(output_buf);
if (cm->rank > 1) {
cudaFree(local_input_buf); cudaFree(local_input_buf);
cudaFree(local_output_buf); cudaFree(local_output_buf);
}
cudaFree(d_pos); cudaFree(d_pos);
delete [] pos; delete [] pos;
delete [] gate; delete [] gate;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment