Commit 7a2ad4a1 authored by Rick Ho

use expert_count array instead of expert_n

parent 307e0ad9
@@ -58,7 +58,7 @@ void moe_cuda_expert_count_impl(
         ++expert_count[gate[i]];
     }
     expert_ptr[0] = 0;
-    for (int i = 1; i < tot_expert; ++i) {
+    for (int i = 1; i < num_expert; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
     }
@@ -67,7 +67,7 @@ void moe_cuda_expert_count_impl(
     for (int i = 0; i < batch_size; ++i) {
         pos[i] = expert_ptr[gate[i]]++;
     }
-    for (int i = tot_expert - 1; i > 0; --i) {
+    for (int i = num_expert - 1; i > 0; --i) {
         expert_ptr[i] = expert_ptr[i - 1];
     }
     expert_ptr[0] = 0;
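
For reference, the two hunks above implement a counting-sort layout over experts: `expert_count` histograms the gate assignments, `expert_ptr` holds the exclusive prefix sum of that histogram, and `pos[i]` records where row `i` lands in the expert-sorted buffer. A minimal host-only sketch of that scheme, with the array names borrowed from the diff and a hypothetical free-standing signature:

    #include <vector>

    // Sort batch rows by their expert assignment with a counting sort.
    // gate[i] in [0, num_expert) is the expert chosen for row i;
    // pos[i] becomes the destination slot of row i in the packed buffer.
    void count_and_position(const int* gate, int batch_size, int num_expert,
                            std::vector<int>& expert_count, std::vector<int>& pos) {
        expert_count.assign(num_expert, 0);
        pos.assign(batch_size, 0);
        std::vector<int> expert_ptr(num_expert, 0);
        for (int i = 0; i < batch_size; ++i) {
            ++expert_count[gate[i]];            // histogram of assignments
        }
        for (int i = 1; i < num_expert; ++i) {  // exclusive prefix sum
            expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
        }
        for (int i = 0; i < batch_size; ++i) {  // stable scatter positions
            pos[i] = expert_ptr[gate[i]]++;
        }
    }

After this runs, the rows assigned to expert `e` occupy a contiguous block of `expert_count[e]` slots, which is what the per-expert GEMMs further down rely on.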
@@ -77,6 +77,8 @@ void moe_cuda_expert_count_impl(
     delete [] expert_ptr;
 }
+#ifdef MOE_USE_NCCL
 void moe_cuda_global_scatter() {
     if (cm->size > 1) {
         if (expert_sz) {
@@ -118,6 +120,40 @@ void moe_cuda_global_scatter() {
     }
 }
+void moe_cuda_global_gather() {
+    if (cm->size > 1) {
+        int send_ptr = 0;
+        for (int i = 0; i < num_expert; ++i) {
+            NCCL_SAFE_CALL(ncclGroupStart());
+            for (int j = 0; j < cm->size; ++j) {
+                int idx = i + j * num_expert;
+                if (all_expert_count[idx]) {
+                    NCCL_SAFE_CALL(ncclSend(
+                            output_buf + send_ptr * out_feat,
+                            all_expert_count[idx] * out_feat * sizeof(scalar_t),
+                            ncclChar,
+                            j,
+                            cm->ncclcomm,
+                            h->getStream(0)));
+                    send_ptr += all_expert_count[idx];
+                }
+                if (expert_count[idx]) {
+                    NCCL_SAFE_CALL(ncclRecv(
+                            local_output_buf + expert_ptr[idx] * out_feat,
+                            expert_count[idx] * out_feat * sizeof(scalar_t),
+                            ncclChar,
+                            j,
+                            cm->ncclcomm,
+                            h->getStream(0)));
+                }
+            }
+            NCCL_SAFE_CALL(ncclGroupEnd());
+        }
+    }
+}
+#endif // MOE_USE_NCCL
 template <typename scalar_t>
 void moe_cuda_local_scatter_impl(
         const scalar_t* input,
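
The relocated `moe_cuda_global_gather` above pairs every `ncclSend` with a matching `ncclRecv` inside an `ncclGroupStart()`/`ncclGroupEnd()` region, which is NCCL's way of expressing an all-to-all-style exchange of variable-sized slices without deadlocking. A stripped-down sketch of that pattern, assuming an already-initialized `ncclComm_t` and a single CUDA stream; the buffer names are illustrative, and the payload is typed as `float` here whereas the diff sends raw bytes as `ncclChar`:

    #include <nccl.h>
    #include <cuda_runtime.h>

    // Exchange a variable number of rows with every peer: send_count[j] rows
    // go to rank j, recv_count[j] rows arrive from rank j. Counts are in rows
    // of `feat` floats; offsets are running prefix sums over the counts.
    // Error handling (NCCL_SAFE_CALL in the diff) is omitted for brevity.
    void exchange_rows(const float* send_buf, const int* send_count,
                       float* recv_buf, const int* recv_count,
                       int feat, int world_size,
                       ncclComm_t comm, cudaStream_t stream) {
        int send_ofs = 0, recv_ofs = 0;
        ncclGroupStart();  // batch all point-to-point calls so they match up
        for (int j = 0; j < world_size; ++j) {
            if (send_count[j]) {
                ncclSend(send_buf + (size_t)send_ofs * feat,
                         (size_t)send_count[j] * feat, ncclFloat, j, comm, stream);
                send_ofs += send_count[j];
            }
            if (recv_count[j]) {
                ncclRecv(recv_buf + (size_t)recv_ofs * feat,
                         (size_t)recv_count[j] * feat, ncclFloat, j, comm, stream);
                recv_ofs += recv_count[j];
            }
        }
        ncclGroupEnd();  // communication is issued on `stream` here
    }

Batching the point-to-point calls in one group lets NCCL match sends and receives across ranks and progress them concurrently; issuing them one by one can block once both sides wait on a send.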
@@ -170,7 +206,7 @@ void moe_cuda_forward_impl(
     scalar_t alpha = 1, beta = 0;
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
-        if (expert_n[i] == 0) {
+        if (expert_count[i] == 0) {
             continue;
         }
         // Use T(B) x T(A) = T(C) to produce row-major C
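
The `T(B) x T(A) = T(C)` comment refers to the standard trick for getting a row-major product out of column-major cuBLAS: a row-major matrix reinterpreted as column-major is its transpose, so computing T(C) = T(B) x T(A) with the operands swapped leaves C laid out row-major. A self-contained illustration with `cublasSgemm`, using generic shapes rather than the exact operand layout of `moe_cuda_forward_impl`:

    #include <cublas_v2.h>

    // Row-major C[M,N] = A[M,K] * B[K,N] on column-major cuBLAS:
    // pass B as the first operand and A as the second, with leading
    // dimensions taken from the row-major strides (N, K, N).
    void gemm_row_major(cublasHandle_t handle,
                        const float* A, const float* B, float* C,
                        int M, int N, int K) {
        const float alpha = 1.0f, beta = 0.0f;
        // Computes C^T (N x M, column-major) = B^T (N x K) * A^T (K x M),
        // which is exactly C (M x N) in row-major memory.
        cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                    N, M, K,
                    &alpha,
                    B, N,   // ldb = row-major stride of B
                    A, K,   // lda = row-major stride of A
                    &beta,
                    C, N);  // ldc = row-major stride of C
    }

In the per-expert loop above, this kind of call is issued once per non-empty expert, with the batch dimension equal to `expert_count[i]` and the pointers offset into the packed buffers.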
@@ -240,44 +276,12 @@ void moe_cuda_backward_impl(
             grad_weight + i * in_feat * out_feat, in_feat
         ));
-        ptr += expert_n[i];
+        ptr += expert_count[i];
     }
     smgr->sync(num_expert);
 }
-void moe_cuda_global_gather() {
-    if (cm->size > 1) {
-        int send_ptr = 0;
-        for (int i = 0; i < num_expert; ++i) {
-            NCCL_SAFE_CALL(ncclGroupStart());
-            for (int j = 0; j < cm->size; ++j) {
-                int idx = i + j * num_expert;
-                if (all_expert_count[idx]) {
-                    NCCL_SAFE_CALL(ncclSend(
-                            output_buf + send_ptr * out_feat,
-                            all_expert_count[idx] * out_feat * sizeof(scalar_t),
-                            ncclChar,
-                            j,
-                            cm->ncclcomm,
-                            h->getStream(0)));
-                    send_ptr += all_expert_count[idx];
-                }
-                if (expert_count[idx]) {
-                    NCCL_SAFE_CALL(ncclRecv(
-                            local_output_buf + expert_ptr[idx] * out_feat,
-                            expert_count[idx] * out_feat * sizeof(scalar_t),
-                            ncclChar,
-                            j,
-                            cm->ncclcomm,
-                            h->getStream(0)));
-                }
-            }
-            NCCL_SAFE_CALL(ncclGroupEnd());
-        }
-    }
-}
 std::vector<torch::Tensor> moe_cuda_expert_count(
         torch::Tensor gate,
         size_t num_expert) {
......
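
The backward hunk ends with `smgr->sync(num_expert)`, and the NCCL calls go through `h->getStream(0)`, which suggests a pool of CUDA streams with per-expert work issued asynchronously and synchronized once at the end. A generic sketch of that schedule, with a toy kernel standing in for the real per-expert GEMMs; the stream pool and kernel here are illustrative stand-ins, not the repository's stream manager:

    #include <cuda_runtime.h>
    #include <vector>

    // Toy per-expert kernel: scales one expert's slice of the packed buffer.
    __global__ void scale_rows(float* data, int rows, int feat, float s) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < rows * feat) data[i] *= s;
    }

    // Walk the expert-sorted buffer expert by expert, launching each expert's
    // work on its own stream, then synchronize everything at the end -- the
    // same shape of schedule the forward/backward loops follow before the
    // final sync over num_expert streams.
    void per_expert_schedule(float* packed, const int* expert_count,
                             int num_expert, int feat,
                             const std::vector<cudaStream_t>& streams) {
        for (int i = 0, ptr = 0; i < num_expert; ++i) {
            if (expert_count[i] == 0) continue;       // nothing routed here
            int n = expert_count[i] * feat;
            int threads = 256, blocks = (n + threads - 1) / threads;
            scale_rows<<<blocks, threads, 0, streams[i % streams.size()]>>>(
                packed + (size_t)ptr * feat, expert_count[i], feat, 2.0f);
            ptr += expert_count[i];                   // advance to next expert's rows
        }
        for (cudaStream_t s : streams) {
            cudaStreamSynchronize(s);                 // wait for every expert's work
        }
    }

The `ptr` bookkeeping mirrors the loops above: it advances by `expert_count[i]` so each expert only touches its contiguous slice of the packed buffer.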