Commit 691e92e1 authored by Rick Ho's avatar Rick Ho
Browse files

fix ptr correctness

parent 7e949a62
......@@ -20,8 +20,9 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#define MOE_DEBUG
// #define MOE_BREAKDOWN
// #define MOE_DEBUG_SCATTER
#define MOE_DEBUG_SCATTER
template <typename scalar_t>
__global__
......@@ -56,6 +57,12 @@ void batch_gather_kernel(int wid, int* pos,
}
}
template <typename scalar_t>
scalar_t print_first_float(scalar_t* d_ptr) {
scalar_t v;
cudaMemcpy(&v, d_ptr, sizeof(scalar_t), cudaMemcpyDeviceToHost);
return v;
}
template <typename scalar_t>
void moe_cuda_forward_impl(
......@@ -80,14 +87,10 @@ void moe_cuda_forward_impl(
scalar_t *local_input_buf, *local_output_buf;
checkCudaErrors(cudaMalloc(&local_input_buf, sizeof(scalar_t) * batch_size *
in_feat));
#ifdef MOE_BREAKDOWN
timestamp(t_malloc);
fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) *
1e6);
#endif
checkCudaErrors(cudaMalloc(&local_input_buf,
sizeof(scalar_t) * batch_size * in_feat));
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
int *gate = new int[batch_size];
int *expert_count = new int[tot_expert], *expert_ptr = new int[tot_expert];
......@@ -96,12 +99,6 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost));
#ifdef MOE_BREAKDOWN
timestamp(t_cpy);
fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) *
1e6);
#endif
for (int i = 0; i < batch_size; ++i) {
++expert_count[gate[i]];
}
......@@ -117,6 +114,10 @@ void moe_cuda_forward_impl(
for (int i = 0; i < batch_size; ++i) {
pos[i] = expert_ptr[gate[i]]++;
}
for (int i = batch_size - 1; i > 0; --i) {
expert_ptr[i] = expert_ptr[i - 1];
}
expert_ptr[0] = 0;
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice));
......@@ -133,25 +134,16 @@ void moe_cuda_forward_impl(
}
expert_sz += expert_n[i];
}
scalar_t *input_buf, *hidden_buf, *output_buf;
checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat));
#ifdef MOE_DEBUG
for (int i = 0; i < tot_expert; ++i) {
fprintf(stderr, "%d %d %d\n", cm->rank, i, expert_count[i]);
}
if (cm->rank == 0) {
for (int i = 0; i < tot_expert; ++i) {
fprintf(stderr, "%d ",all_expert_count[i]);
}
fprintf(stderr, "\n");
scalar_t *input_buf, *hidden_buf, *output_buf;
if (expert_sz) {
checkCudaErrors(cudaMalloc(&hidden_buf,
sizeof(scalar_t) * expert_sz * hidden_feat));
}
#endif
#ifdef MOE_BREAKDOWN
timestamp(t_expert);
fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_init, t_expert) *
1e6);
#endif
......@@ -159,40 +151,45 @@ void moe_cuda_forward_impl(
<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
local_input_buf);
h->sync(0);
if (cm->rank > 1) {
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
ncclGroupStart();
// fprintf(stderr, "First %d lin %.3f\n", cm->rank, print_first_float(local_input_buf));
if (cm->size > 1) {
if (expert_sz) {
checkCudaErrors(cudaMalloc(&input_buf,
sizeof(scalar_t) * expert_sz * in_feat));
checkCudaErrors(cudaMalloc(&output_buf,
sizeof(scalar_t) * expert_sz * out_feat));
}
int recv_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < cm->size; ++j) {
int send_id = i + j * num_expert;
if (expert_count[send_id]) {
ncclSend(local_input_buf + expert_ptr[send_id] * in_feat,
expert_count[send_id] * in_feat * sizeof(scalar_t),
int idx = i + j * num_expert;
if (expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
h->getStream(0)));
}
int recv_id = i * cm->size + j;
if (all_expert_count[recv_id]) {
ncclRecv(input_buf + recv_ptr * in_feat,
all_expert_count[recv_id] * in_feat * sizeof(scalar_t),
if (all_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
input_buf + recv_ptr * in_feat,
all_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
recv_ptr += all_expert_count[recv_id];
h->getStream(0)));
recv_ptr += all_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
ncclGroupEnd();
} else {
input_buf = local_input_buf;
output_buf = local_output_buf;
}
#ifdef MOE_BREAKDOWN
......@@ -202,6 +199,9 @@ void moe_cuda_forward_impl(
1e6);
#endif
h->sync(0);
// fprintf(stderr, "First %d in %.3f\n", cm->rank, print_first_float(input_buf));
scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) {
......@@ -209,9 +209,8 @@ void moe_cuda_forward_impl(
continue;
}
#ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "gemm %d sz %d\n", i, expert_n[i]);
fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_n[i],
in_feat);
fprintf(stderr, "worker %d gemm %d sz %d offset %d\n", cm->rank, i, expert_n[i], ptr);
// fprintf(stderr, "worker %d GeMM %d x %d x %d\n", cm->rank, out_feat, expert_n[i], in_feat);
#endif
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(h->getHandle(i),
......@@ -246,37 +245,34 @@ void moe_cuda_forward_impl(
1e6);
#endif
if (cm->rank > 1) {
checkCudaErrors(cudaMalloc(&local_output_buf,
sizeof(scalar_t) * batch_size * out_feat));
ncclGroupStart();
if (cm->size > 1) {
int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < cm->size; ++j) {
int recv_id = i + j * num_expert;
if (expert_count[recv_id]) {
ncclRecv(local_output_buf + expert_ptr[recv_id] * in_feat,
expert_count[recv_id] * in_feat * sizeof(scalar_t),
int idx = i + j * num_expert;
if (expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
h->getStream(0)));
}
int send_id = i * cm->size + j;
if (all_expert_count[send_id]) {
ncclSend(output_buf + send_ptr * in_feat,
all_expert_count[send_id] * in_feat * sizeof(scalar_t),
if (all_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
all_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0));
send_ptr += all_expert_count[send_id];
h->getStream(0)));
send_ptr += all_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
ncclGroupEnd();
} else {
local_output_buf = output_buf;
}
batch_gather_kernel<scalar_t>
......@@ -292,13 +288,15 @@ void moe_cuda_forward_impl(
1e6);
#endif
cudaFree(input_buf);
cudaFree(hidden_buf);
cudaFree(output_buf);
if (cm->rank > 1) {
cudaFree(local_input_buf);
cudaFree(local_output_buf);
if (expert_sz) {
cudaFree(hidden_buf);
if (cm->size > 1) {
cudaFree(input_buf);
cudaFree(output_buf);
}
}
cudaFree(local_input_buf);
cudaFree(local_output_buf);
cudaFree(d_pos);
delete [] pos;
delete [] gate;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment