"src/vscode:/vscode.git/clone" did not exist on "f8ed456e79c726a490b5223d9ecf4bcbc1811648"
Commit 7a2ad4a1 authored by Rick Ho's avatar Rick Ho
Browse files

use expert_count array instead of expert_n

parent 307e0ad9
...@@ -58,7 +58,7 @@ void moe_cuda_expert_count_impl( ...@@ -58,7 +58,7 @@ void moe_cuda_expert_count_impl(
++expert_count[gate[i]]; ++expert_count[gate[i]];
} }
expert_ptr[0] = 0; expert_ptr[0] = 0;
for (int i = 1; i < tot_expert; ++i) { for (int i = 1; i < num_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1]; expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
} }
...@@ -67,7 +67,7 @@ void moe_cuda_expert_count_impl( ...@@ -67,7 +67,7 @@ void moe_cuda_expert_count_impl(
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
pos[i] = expert_ptr[gate[i]]++; pos[i] = expert_ptr[gate[i]]++;
} }
for (int i = tot_expert - 1; i > 0; --i) { for (int i = num_expert - 1; i > 0; --i) {
expert_ptr[i] = expert_ptr[i - 1]; expert_ptr[i] = expert_ptr[i - 1];
} }
expert_ptr[0] = 0; expert_ptr[0] = 0;
...@@ -77,6 +77,8 @@ void moe_cuda_expert_count_impl( ...@@ -77,6 +77,8 @@ void moe_cuda_expert_count_impl(
delete [] expert_ptr; delete [] expert_ptr;
} }
#ifdef MOE_USE_NCCL
void moe_cuda_global_scatter() { void moe_cuda_global_scatter() {
if (cm->size > 1) { if (cm->size > 1) {
if (expert_sz) { if (expert_sz) {
...@@ -118,6 +120,40 @@ void moe_cuda_global_scatter() { ...@@ -118,6 +120,40 @@ void moe_cuda_global_scatter() {
} }
} }
void moe_cuda_global_gather() {
if (cm->size > 1) {
int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < cm->size; ++j) {
int idx = i + j * num_expert;
if (all_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
all_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0)));
send_ptr += all_expert_count[idx];
}
if (expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
}
#endif // MOE_USE_NCCL
template <typename scalar_t> template <typename scalar_t>
void moe_cuda_local_scatter_impl( void moe_cuda_local_scatter_impl(
const scalar_t* input, const scalar_t* input,
...@@ -170,7 +206,7 @@ void moe_cuda_forward_impl( ...@@ -170,7 +206,7 @@ void moe_cuda_forward_impl(
scalar_t alpha = 1, beta = 0; scalar_t alpha = 1, beta = 0;
for (int i = 0, ptr = 0; i < num_expert; ++i) { for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_n[i] == 0) { if (expert_count[i] == 0) {
continue; continue;
} }
// Use T(B) x T(A) = T(C) to produce row-major C // Use T(B) x T(A) = T(C) to produce row-major C
...@@ -240,44 +276,12 @@ void moe_cuda_backward_impl( ...@@ -240,44 +276,12 @@ void moe_cuda_backward_impl(
grad_weight + i * in_feat * out_feat, in_feat grad_weight + i * in_feat * out_feat, in_feat
)); ));
ptr += expert_n[i]; ptr += expert_count[i];
} }
smgr->sync(num_expert); smgr->sync(num_expert);
} }
void moe_cuda_global_gather() {
if (cm->size > 1) {
int send_ptr = 0;
for (int i = 0; i < num_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < cm->size; ++j) {
int idx = i + j * num_expert;
if (all_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
all_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0)));
send_ptr += all_expert_count[idx];
}
if (expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
cm->ncclcomm,
h->getStream(0)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
}
std::vector<torch::Tensor> moe_cuda_expert_count( std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor gate, torch::Tensor gate,
size_t num_expert) { size_t num_expert) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment