Commit 7e949a62 authored by Rick Ho's avatar Rick Ho
Browse files

fix output bug

parent 4cb75d42
...@@ -82,8 +82,6 @@ void moe_cuda_forward_impl( ...@@ -82,8 +82,6 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size * checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
in_feat)); in_feat));
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_malloc); timestamp(t_malloc);
...@@ -136,12 +134,8 @@ void moe_cuda_forward_impl( ...@@ -136,12 +134,8 @@ void moe_cuda_forward_impl(
expert_sz += expert_n[i]; expert_sz += expert_n[i];
} }
scalar_t *input_buf, *hidden_buf, *output_buf; scalar_t *input_buf, *hidden_buf, *output_buf;
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&hidden_buf, checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat)); sizeof(scalar_t) * expert_sz * hidden_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
#ifdef MOE_DEBUG #ifdef MOE_DEBUG
for (int i = 0; i < tot_expert; ++i) { for (int i = 0; i < tot_expert; ++i) {
...@@ -166,32 +160,40 @@ void moe_cuda_forward_impl( ...@@ -166,32 +160,40 @@ void moe_cuda_forward_impl(
local_input_buf); local_input_buf);
h->sync(0); h->sync(0);
ncclGroupStart(); if (cm->rank > 1) {
int recv_ptr = 0; checkCudaErrors(cudaMalloc(&input_buf,
for (int i = 0; i < num_expert; ++i) { sizeof(scalar_t) * expert_sz * in_feat));
for (int j = 0; j < cm->size; ++j) { checkCudaErrors(cudaMalloc(&output_buf,
int send_id = i + j * num_expert; sizeof(scalar_t) * expert_sz * out_feat));
if (expert_count[send_id]) { ncclGroupStart();
ncclSend(local_input_buf + expert_ptr[send_id] * in_feat, int recv_ptr = 0;
expert_count[send_id] * in_feat * sizeof(scalar_t), for (int i = 0; i < num_expert; ++i) {
ncclChar, for (int j = 0; j < cm->size; ++j) {
j, int send_id = i + j * num_expert;
cm->ncclcomm, if (expert_count[send_id]) {
h->getStream(0)); ncclSend(local_input_buf + expert_ptr[send_id] * in_feat,
} expert_count[send_id] * in_feat * sizeof(scalar_t),
int recv_id = i * cm->size + j; ncclChar,
if (all_expert_count[recv_id]) { j,
ncclRecv(input_buf + recv_ptr * in_feat, cm->ncclcomm,
all_expert_count[recv_id] * in_feat * sizeof(scalar_t), h->getStream(0));
ncclChar, }
j, int recv_id = i * cm->size + j;
cm->ncclcomm, if (all_expert_count[recv_id]) {
h->getStream(0)); ncclRecv(input_buf + recv_ptr * in_feat,
recv_ptr += all_expert_count[recv_id]; all_expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
recv_ptr += all_expert_count[recv_id];
}
} }
} }
ncclGroupEnd();
} else {
input_buf = local_input_buf;
} }
ncclGroupEnd();
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
h->sync(); h->sync();
...@@ -244,32 +246,38 @@ void moe_cuda_forward_impl( ...@@ -244,32 +246,38 @@ void moe_cuda_forward_impl(
1e6); 1e6);
#endif #endif
ncclGroupStart(); if (cm->rank > 1) {
int send_ptr = 0; checkCudaErrors(cudaMalloc(&local_output_buf,
for (int i = 0; i < num_expert; ++i) { sizeof(scalar_t) * batch_size * out_feat));
for (int j = 0; j < cm->size; ++j) { ncclGroupStart();
int recv_id = i + j * num_expert; int send_ptr = 0;
if (expert_count[recv_id]) { for (int i = 0; i < num_expert; ++i) {
ncclRecv(local_input_buf + expert_ptr[recv_id] * in_feat, for (int j = 0; j < cm->size; ++j) {
expert_count[recv_id] * in_feat * sizeof(scalar_t), int recv_id = i + j * num_expert;
ncclChar, if (expert_count[recv_id]) {
j, ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat,
cm->ncclcomm, expert_count[recv_id] * in_feat * sizeof(scalar_t),
h->getStream(0)); ncclChar,
} j,
int send_id = i * cm->size + j; cm->ncclcomm,
if (all_expert_count[send_id]) { h->getStream(0));
ncclSend(input_buf + send_ptr * in_feat, }
all_expert_count[send_id] * in_feat * sizeof(scalar_t), int send_id = i * cm->size + j;
ncclChar, if (all_expert_count[send_id]) {
j, ncclSend(output_buf + send_ptr * in_feat,
cm->ncclcomm, all_expert_count[send_id] * in_feat * sizeof(scalar_t),
h->getStream(0)); ncclChar,
send_ptr += all_expert_count[send_id]; j,
cm->ncclcomm,
h->getStream(0));
send_ptr += all_expert_count[send_id];
}
} }
} }
ncclGroupEnd();
} else {
local_output_buf = output_buf;
} }
ncclGroupEnd();
batch_gather_kernel<scalar_t> batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, <<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
...@@ -287,8 +295,10 @@ void moe_cuda_forward_impl( ...@@ -287,8 +295,10 @@ void moe_cuda_forward_impl(
cudaFree(input_buf); cudaFree(input_buf);
cudaFree(hidden_buf); cudaFree(hidden_buf);
cudaFree(output_buf); cudaFree(output_buf);
cudaFree(local_input_buf); if (cm->rank > 1) {
cudaFree(local_output_buf); cudaFree(local_input_buf);
cudaFree(local_output_buf);
}
cudaFree(d_pos); cudaFree(d_pos);
delete [] pos; delete [] pos;
delete [] gate; delete [] gate;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment