Commit 7e949a62 authored by Rick Ho
Browse files

fix output bug

parent 4cb75d42
......@@ -82,8 +82,6 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
in_feat));
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
#ifdef MOE_BREAKDOWN
timestamp(t_malloc);
......@@ -136,12 +134,8 @@ void moe_cuda_forward_impl(
expert_sz += expert_n[i];
}
scalar_t *input_buf, *hidden_buf, *output_buf;
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
#ifdef MOE_DEBUG
for (int i = 0; i < tot_expert; ++i) {
......@@ -166,32 +160,40 @@ void moe_cuda_forward_impl(
local_input_buf);
h->sync(0);
ncclGroupStart();
int recv_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) {
int send_id = i + j * num_expert;
if (expert_count[send_id]) {
ncclSend(local_input_buf + expert_ptr[send_id] * in_feat,
expert_count[send_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
}
int recv_id = i * cm->size + j;
if (all_expert_count[recv_id]) {
ncclRecv(input_buf + recv_ptr * in_feat,
all_expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
recv_ptr += all_expert_count[recv_id];
if (cm->rank > 1) {
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
ncclGroupStart();
int recv_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) {
int send_id = i + j * num_expert;
if (expert_count[send_id]) {
ncclSend(local_input_buf + expert_ptr[send_id] * in_feat,
expert_count[send_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
}
int recv_id = i * cm->size + j;
if (all_expert_count[recv_id]) {
ncclRecv(input_buf + recv_ptr * in_feat,
all_expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
recv_ptr += all_expert_count[recv_id];
}
}
}
ncclGroupEnd();
} else {
input_buf = local_input_buf;
}
ncclGroupEnd();
#ifdef MOE_BREAKDOWN
h->sync();
......@@ -244,32 +246,38 @@ void moe_cuda_forward_impl(
1e6);
#endif
ncclGroupStart();
int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) {
int recv_id = i + j * num_expert;
if (expert_count[recv_id]) {
ncclRecv(local_input_buf + expert_ptr[recv_id] * in_feat,
expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
}
int send_id = i * cm->size + j;
if (all_expert_count[send_id]) {
ncclSend(input_buf + send_ptr * in_feat,
all_expert_count[send_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
send_ptr += all_expert_count[send_id];
if (cm->rank > 1) {
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
ncclGroupStart();
int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
for (int j = 0; j < cm->size; ++j) {
int recv_id = i + j * num_expert;
if (expert_count[recv_id]) {
ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat,
expert_count[recv_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
}
int send_id = i * cm->size + j;
if (all_expert_count[send_id]) {
ncclSend(output_buf + send_ptr * in_feat,
all_expert_count[send_id] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
send_ptr += all_expert_count[send_id];
}
}
}
ncclGroupEnd();
} else {
local_output_buf = output_buf;
}
ncclGroupEnd();
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
......@@ -287,8 +295,10 @@ void moe_cuda_forward_impl(
cudaFree(input_buf);
cudaFree(hidden_buf);
cudaFree(output_buf);
cudaFree(local_input_buf);
cudaFree(local_output_buf);
if (cm->rank > 1) {
cudaFree(local_input_buf);
cudaFree(local_output_buf);
}
cudaFree(d_pos);
delete [] pos;
delete [] gate;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment