Commit 6cb6bbe4 authored by Rick Ho
Browse files

global exchange update variable name

parent bb92d30e
......@@ -7,14 +7,14 @@
// Exchanges per-expert counts with the other workers and returns the counts
// obtained from them.
//
// NOTE(review): this listing is a diff rendered without +/- markers. Where two
// near-identical lines appear back to back, the first (using `num_expert`) is
// presumably the pre-commit line and the second (using `n_expert`) its
// post-commit replacement -- confirm against the repository before compiling.
//
// Parameters:
//   local_expert_count  tensor whose storage is read through data_ptr<long>();
//                       its device index selects the CUDA stream manager.
//   num_expert/n_expert expert count forwarded unchanged to
//                       fmoe_cuda_expert_exchange_impl.
//   n_workers           worker count forwarded unchanged to the same call.
//
// Returns a one-element vector holding `global_expert_count`.
std::vector<torch::Tensor> _expert_exchange(
torch::Tensor local_expert_count,
long num_expert, long n_workers) {
long n_expert, long n_workers) {
// Output tensor created with torch::empty_like from the input; its contents
// are expected to be written by the exchange call below.
auto global_expert_count = torch::empty_like(local_expert_count);
// Stream manager is looked up by the device index of the input tensor.
auto smgr = getCudaStreamManager(local_expert_count.device().index());
fmoe_cuda_expert_exchange_impl(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
num_expert, n_workers,
n_expert, n_workers,
smgr);
return {global_expert_count};
}
......@@ -26,7 +26,7 @@ std::vector<torch::Tensor> _global_scatter(
long batch_size, long n_workers) {
CHECK_INPUT(input_buf);
auto num_expert = local_expert_count.size(0) / n_workers;
auto n_expert = local_expert_count.size(0) / n_workers;
auto in_feat = input_buf.size(1);
auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
auto smgr = getCudaStreamManager(input_buf.device().index());
......@@ -38,7 +38,7 @@ std::vector<torch::Tensor> _global_scatter(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
global_input_buf.data_ptr<scalar_t>(),
in_feat, num_expert, n_workers,
in_feat, n_expert, n_workers,
smgr
);
}));
......@@ -52,7 +52,7 @@ std::vector<torch::Tensor> _global_gather(
long batch_size, long n_workers) {
CHECK_INPUT(output_buf);
auto num_expert = local_expert_count.size(0) / n_workers;
auto n_expert = local_expert_count.size(0) / n_workers;
auto out_feat = output_buf.size(1);
auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
auto smgr = getCudaStreamManager(output_buf.device().index());
......@@ -64,7 +64,7 @@ std::vector<torch::Tensor> _global_gather(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
local_output_buf.data_ptr<scalar_t>(),
out_feat, num_expert, n_workers,
out_feat, n_expert, n_workers,
smgr
);
}));
......
......@@ -4,20 +4,20 @@
void fmoe_cuda_expert_exchange_impl(
const long* local_expert_count,
long* global_expert_count,
int num_expert, int world_size,
int n_expert, int world_size,
CudaStreamManager* smgr) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int i = 0; i < world_size; ++i) {
NCCL_SAFE_CALL(ncclSend(
local_expert_count + num_expert * i,
num_expert,
local_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + num_expert * i,
num_expert,
global_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
......@@ -33,21 +33,21 @@ void fmoe_cuda_global_scatter_impl(
const long* local_expert_count,
const long* global_expert_count,
scalar_t* input_buf,
size_t in_feat, size_t num_expert, size_t world_size,
size_t in_feat, size_t n_expert, size_t world_size,
CudaStreamManager* smgr) {
// assert world_size > 1
int recv_ptr = 0;
/* TODO: may save for backward */
long*expert_ptr = new long[num_expert * world_size];
long*expert_ptr = new long[n_expert * world_size];
expert_ptr[0] = 0;
for (size_t i = 1; i < num_expert * world_size; ++i) {
for (size_t i = 1; i < n_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (size_t i = 0; i < num_expert; ++i) {
for (size_t i = 0; i < n_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (size_t j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
int idx = i + j * n_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
......@@ -80,20 +80,20 @@ void fmoe_cuda_global_gather_impl(
const long* local_expert_count,
const long* global_expert_count,
scalar_t* local_output_buf,
size_t out_feat, size_t num_expert, size_t world_size,
size_t out_feat, size_t n_expert, size_t world_size,
CudaStreamManager* smgr) {
long send_ptr = 0;
/* TODO: may save for backward */
long *expert_ptr = new long[num_expert * world_size];
long *expert_ptr = new long[n_expert * world_size];
expert_ptr[0] = 0;
for (size_t i = 1; i < num_expert * world_size; ++i) {
for (size_t i = 1; i < n_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (size_t i = 0; i < num_expert; ++i) {
for (size_t i = 0; i < n_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (size_t j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
int idx = i + j * n_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment