Commit 691e92e1 authored by Rick Ho's avatar Rick Ho
Browse files

fix ptr correctness

parent 7e949a62
...@@ -20,8 +20,9 @@ ...@@ -20,8 +20,9 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1) #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#define MOE_DEBUG
// #define MOE_BREAKDOWN // #define MOE_BREAKDOWN
// #define MOE_DEBUG_SCATTER #define MOE_DEBUG_SCATTER
template <typename scalar_t> template <typename scalar_t>
__global__ __global__
...@@ -56,6 +57,12 @@ void batch_gather_kernel(int wid, int* pos, ...@@ -56,6 +57,12 @@ void batch_gather_kernel(int wid, int* pos,
} }
} }
template <typename scalar_t>
scalar_t print_first_float(scalar_t* d_ptr) {
scalar_t v;
cudaMemcpy(&v, d_ptr, sizeof(scalar_t), cudaMemcpyDeviceToHost);
return v;
}
template <typename scalar_t> template <typename scalar_t>
void moe_cuda_forward_impl( void moe_cuda_forward_impl(
...@@ -80,14 +87,10 @@ void moe_cuda_forward_impl( ...@@ -80,14 +87,10 @@ void moe_cuda_forward_impl(
scalar_t *local_input_buf, *local_output_buf; scalar_t *local_input_buf, *local_output_buf;
checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size * checkCudaErrors(cudaMalloc(&local_input_buf,
in_feat)); sizeof(scalar_t) * batch_size * in_feat));
checkCudaErrors(cudaMalloc(&local_output_buf,
#ifdef MOE_BREAKDOWN sizeof(scalar_t) * batch_size * out_feat));
timestamp(t_malloc);
fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) *
1e6);
#endif
int *gate = new int[batch_size]; int *gate = new int[batch_size];
int *expert_count = new int[tot_expert], *expert_ptr = new int[tot_expert]; int *expert_count = new int[tot_expert], *expert_ptr = new int[tot_expert];
...@@ -96,12 +99,6 @@ void moe_cuda_forward_impl( ...@@ -96,12 +99,6 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size, checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
#ifdef MOE_BREAKDOWN
timestamp(t_cpy);
fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) *
1e6);
#endif
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
++expert_count[gate[i]]; ++expert_count[gate[i]];
} }
...@@ -117,6 +114,10 @@ void moe_cuda_forward_impl( ...@@ -117,6 +114,10 @@ void moe_cuda_forward_impl(
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
pos[i] = expert_ptr[gate[i]]++; pos[i] = expert_ptr[gate[i]]++;
} }
for (int i = batch_size - 1; i > 0; --i) {
expert_ptr[i] = expert_ptr[i - 1];
}
expert_ptr[0] = 0;
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size, checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice)); cudaMemcpyHostToDevice));
...@@ -133,25 +134,16 @@ void moe_cuda_forward_impl( ...@@ -133,25 +134,16 @@ void moe_cuda_forward_impl(
} }
expert_sz += expert_n[i]; expert_sz += expert_n[i];
} }
scalar_t *input_buf, *hidden_buf, *output_buf;
checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat));
#ifdef MOE_DEBUG scalar_t *input_buf, *hidden_buf, *output_buf;
for (int i = 0; i < tot_expert; ++i) { if (expert_sz) {
fprintf(stderr, "%d %d %d\n", cm->rank, i, expert_count[i]); checkCudaErrors(cudaMalloc(&hidden_buf,
} sizeof(scalar_t) * expert_sz * hidden_feat));
if (cm->rank == 0) {
for (int i = 0; i < tot_expert; ++i) {
fprintf(stderr, "%d ",all_expert_count[i]);
}
fprintf(stderr, "\n");
} }
#endif
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
timestamp(t_expert); timestamp(t_expert);
fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) * fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_init, t_expert) *
1e6); 1e6);
#endif #endif
...@@ -159,40 +151,45 @@ void moe_cuda_forward_impl( ...@@ -159,40 +151,45 @@ void moe_cuda_forward_impl(
<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input, <<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
local_input_buf); local_input_buf);
h->sync(0); h->sync(0);
// fprintf(stderr, "First %d lin %.3f\n", cm->rank, print_first_float(local_input_buf));
if (cm->rank > 1) {
checkCudaErrors(cudaMalloc(&input_buf, if (cm->size > 1) {
sizeof(scalar_t) * expert_sz * in_feat)); if (expert_sz) {
checkCudaErrors(cudaMalloc(&output_buf, checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * out_feat)); sizeof(scalar_t) * expert_sz * in_feat));
ncclGroupStart(); checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
}
int recv_ptr = 0; int recv_ptr = 0;
for (int i = 0; i < num_expert; ++i) { for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < cm->size; ++j) { for (int j = 0; j < cm->size; ++j) {
int send_id = i + j * num_expert; int idx = i + j * num_expert;
if (expert_count[send_id]) { if (expert_count[idx]) {
ncclSend(local_input_buf + expert_ptr[send_id] * in_feat, NCCL_SAFE_CALL(ncclSend(
expert_count[send_id] * in_feat * sizeof(scalar_t), local_input_buf + expert_ptr[idx] * in_feat,
expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar, ncclChar,
j, j,
cm->ncclcomm, cm->ncclcomm,
h->getStream(0)); h->getStream(0)));
} }
int recv_id = i * cm->size + j; if (all_expert_count[idx]) {
if (all_expert_count[recv_id]) { NCCL_SAFE_CALL(ncclRecv(
ncclRecv(input_buf + recv_ptr * in_feat, input_buf + recv_ptr * in_feat,
all_expert_count[recv_id] * in_feat * sizeof(scalar_t), all_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar, ncclChar,
j, j,
cm->ncclcomm, cm->ncclcomm,
h->getStream(0)); h->getStream(0)));
recv_ptr += all_expert_count[recv_id]; recv_ptr += all_expert_count[idx];
} }
} }
NCCL_SAFE_CALL(ncclGroupEnd());
} }
ncclGroupEnd();
} else { } else {
input_buf = local_input_buf; input_buf = local_input_buf;
output_buf = local_output_buf;
} }
#ifdef MOE_BREAKDOWN #ifdef MOE_BREAKDOWN
...@@ -202,6 +199,9 @@ void moe_cuda_forward_impl( ...@@ -202,6 +199,9 @@ void moe_cuda_forward_impl(
1e6); 1e6);
#endif #endif
h->sync(0);
// fprintf(stderr, "First %d in %.3f\n", cm->rank, print_first_float(input_buf));
scalar_t alpha = 1, beta = 0; scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) { for (int i = 0, ptr = 0; i < num_expert; ++i) {
...@@ -209,9 +209,8 @@ void moe_cuda_forward_impl( ...@@ -209,9 +209,8 @@ void moe_cuda_forward_impl(
continue; continue;
} }
#ifdef MOE_DEBUG_SCATTER #ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "gemm %d sz %d\n", i, expert_n[i]); fprintf(stderr, "worker %d gemm %d sz %d offset %d\n", cm->rank, i, expert_n[i], ptr);
fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_n[i], // fprintf(stderr, "worker %d GeMM %d x %d x %d\n", cm->rank, out_feat, expert_n[i], in_feat);
in_feat);
#endif #endif
// Use T(B) x T(A) = T(C) to produce row-major C // Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(h->getHandle(i), checkCudaErrors(cublasXgemm(h->getHandle(i),
...@@ -246,37 +245,34 @@ void moe_cuda_forward_impl( ...@@ -246,37 +245,34 @@ void moe_cuda_forward_impl(
1e6); 1e6);
#endif #endif
if (cm->rank > 1) { if (cm->size > 1) {
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
ncclGroupStart();
int send_ptr = 0; int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) { for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < cm->size; ++j) { for (int j = 0; j < cm->size; ++j) {
int recv_id = i + j * num_expert; int idx = i + j * num_expert;
if (expert_count[recv_id]) { if (expert_count[idx]) {
ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat, NCCL_SAFE_CALL(ncclRecv(
expert_count[recv_id] * in_feat * sizeof(scalar_t), local_output_buf + expert_ptr[idx] * out_feat,
expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar, ncclChar,
j, j,
cm->ncclcomm, cm->ncclcomm,
h->getStream(0)); h->getStream(0)));
} }
int send_id = i * cm->size + j; if (all_expert_count[idx]) {
if (all_expert_count[send_id]) { NCCL_SAFE_CALL(ncclSend(
ncclSend(output_buf + send_ptr * in_feat, output_buf + send_ptr * out_feat,
all_expert_count[send_id] * in_feat * sizeof(scalar_t), all_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar, ncclChar,
j, j,
cm->ncclcomm, cm->ncclcomm,
h->getStream(0)); h->getStream(0)));
send_ptr += all_expert_count[send_id]; send_ptr += all_expert_count[idx];
} }
} }
NCCL_SAFE_CALL(ncclGroupEnd());
} }
ncclGroupEnd();
} else {
local_output_buf = output_buf;
} }
batch_gather_kernel<scalar_t> batch_gather_kernel<scalar_t>
...@@ -292,13 +288,15 @@ void moe_cuda_forward_impl( ...@@ -292,13 +288,15 @@ void moe_cuda_forward_impl(
1e6); 1e6);
#endif #endif
cudaFree(input_buf); if (expert_sz) {
cudaFree(hidden_buf); cudaFree(hidden_buf);
cudaFree(output_buf); if (cm->size > 1) {
if (cm->rank > 1) { cudaFree(input_buf);
cudaFree(local_input_buf); cudaFree(output_buf);
cudaFree(local_output_buf); }
} }
cudaFree(local_input_buf);
cudaFree(local_output_buf);
cudaFree(d_pos); cudaFree(d_pos);
delete [] pos; delete [] pos;
delete [] gate; delete [] gate;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment