Unverified commit 1f82fb16, authored by Rick Ho, committed by GitHub

Merge pull request #173 from laekov/fit-new-smgr

Fit old code with new smgr
parents 2bd187cb 945004e7
@@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once(
    }
    long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
+    smgr->syncTorch();
    long *gec = _d2h(d_gec, n_worker);
    /* Limit number of incoming samples */
@@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once(
    /* Send limit information back */
    _h2d(gec, d_gec, n_worker);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
+    smgr->syncTorch();
    _d2h(d_lec, lec, n_worker);
    auto d_dropcount = _h2d(drop_count, n_worker);
    ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
-            smgr->ncclcomm, smgr->stream());
+            smgr->ncclcomm, smgr->torchStream());
+    smgr->syncTorch();
    _d2h(d_dropcount, drop_count, n_worker);
    auto d_gcap = _cudamalloc<long>(n_worker);
    _h2d(&cap, d_gcap + rank, 1);
    ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
-            smgr->ncclcomm, smgr->stream());
+            smgr->ncclcomm, smgr->torchStream());
+    smgr->syncTorch();
    auto gcap = _d2h(d_gcap, n_worker);
    /* Re-assign and update counters */
......
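The hunk above runs the drop-count all-reduce and capacity all-gather on PyTorch's current stream and follows each collective with a host-side sync before copying the results back. A minimal sketch of that pattern, assuming a NCCL communicator is already set up (the helper name `allreduce_and_readback` and its arguments are illustrative, not FastMoE API):

```cpp
#include <nccl.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAStream.h>

// Run a collective on PyTorch's current stream, then block the host on that
// stream so the device-side result can be safely copied out and inspected.
void allreduce_and_readback(long* d_buf, long* h_buf, int n_worker,
                            ncclComm_t comm) {
    cudaStream_t s = c10::cuda::getCurrentCUDAStream().stream();
    ncclAllReduce(d_buf, d_buf, n_worker, ncclInt64, ncclSum, comm, s);
    cudaStreamSynchronize(s);  // the moral equivalent of smgr->syncTorch()
    cudaMemcpy(h_buf, d_buf, n_worker * sizeof(long), cudaMemcpyDeviceToHost);
}
```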
@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
        CudaStreamManager* smgr) {
    dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
    dim3 block_dim(1024);
-    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
            ec, cap, eca, n_expert, n_worker);
-    smgr->sync(1);
}

__global__
@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
        CudaStreamManager* smgr) {
    dim3 grid_dim(CEIL(batch_size, 1024));
    dim3 block_dim(1024);
-    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
        gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
    );
-    smgr->sync(1);
}
@@ -44,10 +44,9 @@ void _reduce_grad(
        long expert_size) {
    auto smgr = getCudaStreamManager(t.device().index());
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
    cudaEvent_t evt_stash;
    cudaEventCreate(&evt_stash);
-    cudaEventRecord(evt_stash, torch_stream);
+    cudaEventRecord(evt_stash, smgr->torchStream());
    FMOE_SWE(smgr->stream(0), evt_stash);
    cudaEventDestroy(evt_stash);
......
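`FMOE_SWE` here appears to wrap `cudaStreamWaitEvent`: an event is recorded on the PyTorch stream, and the manager's own stream waits on it before touching the data, so the host thread never blocks. A standalone sketch of that handoff, with placeholder names (`producer`, `worker`) rather than FastMoE identifiers:

```cpp
#include <cuda_runtime.h>

// Make `worker` wait (device-side) for everything already queued on
// `producer`, without stalling the host thread.
void stream_handoff(cudaStream_t producer, cudaStream_t worker) {
    cudaEvent_t evt;
    cudaEventCreate(&evt);
    cudaEventRecord(evt, producer);      // mark the producer's current frontier
    cudaStreamWaitEvent(worker, evt, 0); // worker may not run past this point early
    cudaEventDestroy(evt);               // safe: the wait is already enqueued
}
```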
@@ -122,8 +122,7 @@ void fmoe_cuda_fused_forward_impl(
        long d_model,
        long num_expert, long rank, long world_size, long expert_size,
        long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
-    cudaStreamSynchronize(torch_stream);
+    smgr->syncTorch();
    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
@@ -192,7 +191,7 @@ void fmoe_cuda_fused_forward_impl(
    // C_0 ... C_n
    for (long step = 0; step < n_groups; ++step) {
        FMOE_SWE(smgr->stream(0), input_ready[step]);
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
@@ -203,14 +202,14 @@ void fmoe_cuda_fused_forward_impl(
                    (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
        }
        cudaEventRecord(output_ready[step], smgr->stream(0));
-        cudaEventRecord(output_torch_ready[step], torch_stream);
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
    }

    // Compute over shadowed experts
    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
        if (stored_models[i]) {
            FMOE_SWE(smgr->stream(0), evt_shadow[si]);
-            FMOE_SWE(torch_stream, evt_shadow[si]);
+            FMOE_SWE(smgr->torchStream(), evt_shadow[si]);
            stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
            long offset = local_ptr[i];
            long micro_batch_size = local_expert_count[i];
@@ -282,8 +281,7 @@ void fmoe_cuda_fused_backward_impl(
        long d_model,
        long num_expert, long rank, long world_size,
        long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
-    cudaStreamSynchronize(torch_stream);
+    smgr->syncTorch();
    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
@@ -350,7 +348,7 @@ void fmoe_cuda_fused_backward_impl(
    // C_0 ... C_n
    for (long step = 0; step < n_groups; ++step) {
        FMOE_SWE(smgr->stream(0), input_ready[step]);
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
@@ -362,7 +360,7 @@ void fmoe_cuda_fused_backward_impl(
                    (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
        }
        cudaEventRecord(output_ready[step], smgr->stream(0));
-        cudaEventRecord(output_torch_ready[step], torch_stream);
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
    }

    // Collect gradients for shadowed experts
@@ -370,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
        if (stored_models[i]) {
            if (i / num_expert == rank) {
                FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
-                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
+                FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
                set_grad_fn(si, i % num_expert);
            }
            ++si;
......
@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
                ncclInt64,
                i,
                smgr->ncclcomm,
-                smgr->stream(0)));
+                smgr->torchStream()));
        NCCL_SAFE_CALL(ncclRecv(
                global_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
-                smgr->stream(0)));
+                smgr->torchStream()));
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
-    smgr->sync(1);
}

torch::Tensor _expert_exchange(
......
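The exchange above pairs one `ncclSend` and one `ncclRecv` per peer inside a `ncclGroupStart()`/`ncclGroupEnd()` block, which lets NCCL schedule all of the point-to-point transfers together and avoids deadlock between ranks. A minimal sketch of the same pattern under those assumptions (function and argument names are illustrative, not the FastMoE signatures):

```cpp
#include <nccl.h>

// All-to-all exchange of per-expert counts: every rank sends its slice for
// peer i and receives peer i's slice, all fused into one NCCL group.
void exchange_counts(const long* send_buf, long* recv_buf, int n_expert,
                     int n_worker, ncclComm_t comm, cudaStream_t stream) {
    ncclGroupStart();
    for (int i = 0; i < n_worker; ++i) {
        ncclSend(send_buf + n_expert * i, n_expert, ncclInt64, i, comm, stream);
        ncclRecv(recv_buf + n_expert * i, n_expert, ncclInt64, i, comm, stream);
    }
    ncclGroupEnd();
}
```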
@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
                    ncclChar,
                    j,
                    smgr->ncclcomm,
-                    smgr->stream(0)));
+                    smgr->torchStream()));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
                    ncclChar,
                    j,
                    smgr->ncclcomm,
-                    smgr->stream(0)));
+                    smgr->torchStream()));
                recv_ptr += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
-    smgr->sync(1);
}

template<typename scalar_t>
@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
                    ncclChar,
                    j,
                    smgr->ncclcomm,
-                    smgr->stream(0)));
+                    smgr->torchStream()));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
                    ncclChar,
                    j,
                    smgr->ncclcomm,
-                    smgr->stream(0)));
+                    smgr->torchStream()));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
-    smgr->sync(1);
}
......
@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
        CudaStreamManager* smgr) {
    size_t numel = batch_size * topk;
    assign_pos_kernel
-        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(numel, 256), 256, 0, smgr->torchStream()>>>
        (cum_count, gate, pos, numel, topk);
-    smgr->sync(1);
}

#define PERTHREAD_EXPERTS 256
@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
        const size_t batch_size, const size_t n_expert,
        CudaStreamManager* smgr) {
    expert_count_kernel
-        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->torchStream()>>>
        (gate_idx, expert_count, batch_size, n_expert);
-    smgr->sync(1);
}
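These two kernel launches follow the same migration recipe as the rest of the pull request: launch on `smgr->torchStream()` instead of `smgr->stream(0)` and drop the trailing `smgr->sync(1)`, since ordering against surrounding PyTorch ops now comes for free. A self-contained sketch of the recipe (the kernel and helper are placeholders, not FastMoE code):

```cpp
#include <cuda_runtime.h>
#include <c10/cuda/CUDAStream.h>

__global__ void scale_kernel(float* x, float a, size_t n) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= a;
}

void scale_on_torch_stream(float* x, float a, size_t n) {
    // Launch on the stream PyTorch is currently using; no explicit sync is
    // needed because later tensor ops are queued on the same stream.
    cudaStream_t s = c10::cuda::getCurrentCUDAStream().stream();
    scale_kernel<<<(n + 255) / 256, 256, 0, s>>>(x, a, n);
}
```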
@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl(
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = has_bias ? 1 : 0;
+    smgr->syncTorch();
    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            continue;
@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl(
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
+    smgr->syncTorch();
    scalar_t alpha = 1, beta = 0;

    // bias
......
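Unlike the kernel hunks, the linear forward/backward paths keep their GEMMs on the manager's own streams and handles, so a `syncTorch()` is inserted first: the input tensors were produced on PyTorch's stream and must be complete before a different stream consumes them. A sketch of that fence under the same assumption (the function, handle, and stream arguments are illustrative):

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAStream.h>

// Drain PyTorch's stream, then run a GEMM on a private stream that is not
// otherwise ordered with respect to PyTorch's stream.
void gemm_after_fence(cublasHandle_t handle, cudaStream_t private_stream,
                      const float* A, const float* B, float* C, int n) {
    cudaStreamSynchronize(c10::cuda::getCurrentCUDAStream().stream());
    float alpha = 1.0f, beta = 0.0f;
    cublasSetStream(handle, private_stream);
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                &alpha, A, n, B, n, &beta, C, n);
}
```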
@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) {
    return this->streams[idx % SMGR_N_STREAMS];
}

+cudaStream_t CudaStreamManager::torchStream() {
+    return c10::cuda::getCurrentCUDAStream().stream();
+}
+
cublasHandle_t CudaStreamManager::handle(size_t idx) {
    if (this->use_default) {
        return at::cuda::getCurrentCUDABlasHandle();
@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) {
}

+void CudaStreamManager::syncTorch() {
+    cudaStreamSynchronize(this->torchStream());
+}
+
void CudaStreamManager::sync(int idx) {
    if (this->use_default) {
        return;
......
@@ -34,8 +34,10 @@ public:
    void setup(int);
    void sync(int=0);
+    void syncTorch();
    void destroy();

+    cudaStream_t torchStream();
    cudaStream_t stream(size_t=0);
    cublasHandle_t handle(size_t=0);
......
@@ -37,7 +37,7 @@ class MoEForward(Function):
        try:
            # To skip torch autograd's version check.
            with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
-                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
+                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
        except Exception as e:
            # Ignore the error and fall back for compatibility to older
            # versions of PyTorch
......