Commit 4f9f77f8 authored by Rick Ho

use torchstream everywhere

parent 2bd187cb
@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
     dim3 block_dim(1024);
-    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
             ec, cap, eca, n_expert, n_worker);
-    smgr->sync(1);
 }
 
 __global__
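This hunk shows the pattern repeated throughout the commit: a kernel that used to launch on the stream manager's private stream (smgr->stream(0)) and then block on an explicit smgr->sync(1) now launches on PyTorch's current CUDA stream, so it is ordered with surrounding PyTorch ops automatically and the trailing sync can be dropped. A minimal sketch of the before/after shapes (my_kernel and its arguments are illustrative, not from this diff):

    // Before: launch on a private stream, then block until it drains.
    my_kernel<<<grid, block, 0, smgr->stream(0)>>>(data, n);
    smgr->sync(1);

    // After: enqueue on the stream PyTorch is already using; later
    // torch ops on that stream are serialized behind the kernel.
    my_kernel<<<grid, block, 0, smgr->torchStream()>>>(data, n);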
@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(batch_size, 1024));
     dim3 block_dim(1024);
-    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
         gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
     );
-    smgr->sync(1);
 }
@@ -122,8 +122,7 @@ void fmoe_cuda_fused_forward_impl(
         long d_model,
         long num_expert, long rank, long world_size, long expert_size,
         long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
-    cudaStreamSynchronize(torch_stream);
+    smgr->syncTorch();
     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
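Here the hand-rolled pair of getCurrentCUDAStream() plus cudaStreamSynchronize() collapses into the new smgr->syncTorch() helper (defined near the bottom of this diff). The synchronization itself is presumably still needed because the host code below builds the expert-count pointer tables on the CPU; only the spelling changes:

    // Old, open-coded:
    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
    cudaStreamSynchronize(torch_stream);

    // New, behind the manager:
    smgr->syncTorch();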
@@ -192,7 +191,7 @@ void fmoe_cuda_fused_forward_impl(
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
         FMOE_SWE(smgr->stream(0), input_ready[step]);
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
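FMOE_SWE is FastMoE's stream-wait-event macro; with the captured torch_stream variable gone, the wait is issued against whatever stream smgr->torchStream() reports at that moment. Assuming the macro wraps cudaStreamWaitEvent (its definition is not part of this diff), the expansion would be roughly:

    // Assumed definition: make `stream` wait, on the device and without
    // blocking the host, until `event` has been recorded.
    #define FMOE_SWE(stream, event) cudaStreamWaitEvent(stream, event, 0)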
@@ -203,14 +202,14 @@ void fmoe_cuda_fused_forward_impl(
                 (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], smgr->stream(0));
-        cudaEventRecord(output_torch_ready[step], torch_stream);
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
     }
 
     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             FMOE_SWE(smgr->stream(0), evt_shadow[si]);
-            FMOE_SWE(torch_stream, evt_shadow[si]);
+            FMOE_SWE(smgr->torchStream(), evt_shadow[si]);
             stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
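The paired cudaEventRecord calls implement the usual cross-stream handoff: an event is recorded on the stream that produced a step's output, and a consumer stream waits on it via FMOE_SWE before touching that output. Recording output_torch_ready on smgr->torchStream() keeps the handoff correct now that the torch-side work runs there. The generic pattern, given two cudaStream_t handles, as a sketch:

    cudaEvent_t evt;
    cudaEventCreate(&evt);
    cudaEventRecord(evt, producer_stream);        // marks a point in the producer's queue
    cudaStreamWaitEvent(consumer_stream, evt, 0); // consumer proceeds only past that point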
@@ -282,8 +281,7 @@ void fmoe_cuda_fused_backward_impl(
         long d_model,
         long num_expert, long rank, long world_size,
         long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
-    cudaStreamSynchronize(torch_stream);
+    smgr->syncTorch();
     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
@@ -350,7 +348,7 @@ void fmoe_cuda_fused_backward_impl(
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
         FMOE_SWE(smgr->stream(0), input_ready[step]);
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -362,7 +360,7 @@ void fmoe_cuda_fused_backward_impl(
                 (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], smgr->stream(0));
-        cudaEventRecord(output_torch_ready[step], torch_stream);
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
     }
 
     // Collect gradients for shadowed experts
@@ -370,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
-                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
+                FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
                 set_grad_fn(si, i % num_expert);
             }
             ++si;
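The backward path mirrors the forward path one-for-one: the same three substitutions (syncTorch() at entry, torchStream() at each stream-wait, and torchStream() at each event record) appear at the corresponding wait, record, and gradient-reduce points.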
......
@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
                 ncclInt64,
                 i,
                 smgr->ncclcomm,
-                smgr->stream(0)));
+                smgr->torchStream()));
         NCCL_SAFE_CALL(ncclRecv(
                 global_expert_count + n_expert * i,
                 n_expert,
                 ncclInt64,
                 i,
                 smgr->ncclcomm,
-                smgr->stream(0)));
+                smgr->torchStream()));
     }
     NCCL_SAFE_CALL(ncclGroupEnd());
-    smgr->sync(1);
 }
 
 torch::Tensor _expert_exchange(
......
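NCCL point-to-point calls take the stream they enqueue on as their last argument, so switching that argument to smgr->torchStream() orders the exchange after whatever PyTorch produced on that stream, and the smgr->sync(1) that used to follow ncclGroupEnd() becomes unnecessary. A minimal sketch of the shape (send_buf, recv_buf, count, and peer are illustrative):

    // Paired send/recv inside a group, enqueued on the torch stream.
    NCCL_SAFE_CALL(ncclGroupStart());
    NCCL_SAFE_CALL(ncclSend(send_buf, count, ncclInt64, peer,
            smgr->ncclcomm, smgr->torchStream()));
    NCCL_SAFE_CALL(ncclRecv(recv_buf, count, ncclInt64, peer,
            smgr->ncclcomm, smgr->torchStream()));
    NCCL_SAFE_CALL(ncclGroupEnd());
    // No cudaStreamSynchronize afterwards: later torch ops on the same
    // stream observe the received data in stream order.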
@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
             }
             if (global_expert_count[idx]) {
                 NCCL_SAFE_CALL(ncclRecv(
@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
                 recv_ptr += global_expert_count[idx];
             }
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
-    smgr->sync(1);
 }
 
 template<typename scalar_t>
@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
                 send_ptr += global_expert_count[idx];
             }
             if (local_expert_count[idx]) {
@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
             }
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
-    smgr->sync(1);
 }
......
@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
         CudaStreamManager* smgr) {
     size_t numel = batch_size * topk;
     assign_pos_kernel
-        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(numel, 256), 256, 0, smgr->torchStream()>>>
         (cum_count, gate, pos, numel, topk);
-    smgr->sync(1);
 }
 
 #define PERTHREAD_EXPERTS 256
@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
         const size_t batch_size, const size_t n_expert,
         CudaStreamManager* smgr) {
     expert_count_kernel
-        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->torchStream()>>>
         (gate_idx, expert_count, batch_size, n_expert);
-    smgr->sync(1);
 }
@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl(
         CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = has_bias ? 1 : 0;
+    smgr->syncTorch();
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             continue;
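Note the asymmetry here: the per-expert GEMMs in this file still run on the manager's own streams and cuBLAS handles, so instead of launching on the torch stream, a syncTorch() is added at function entry. It acts as a barrier, presumably ensuring the input activations produced on the torch stream are complete before the side streams start consuming them:

    smgr->syncTorch();   // drain torch-stream producers first
    // ... per-expert GEMMs on smgr->handle(...) / smgr->stream(...) follow ...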
@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl(
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
+    smgr->syncTorch();
     scalar_t alpha = 1, beta = 0;
     // bias
......
@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) {
     return this->streams[idx % SMGR_N_STREAMS];
 }
 
+cudaStream_t CudaStreamManager::torchStream() {
+    return c10::cuda::getCurrentCUDAStream().stream();
+}
+
 cublasHandle_t CudaStreamManager::handle(size_t idx) {
     if (this->use_default) {
         return at::cuda::getCurrentCUDABlasHandle();
@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) {
 }
 
+void CudaStreamManager::syncTorch() {
+    cudaStreamSynchronize(this->torchStream());
+}
+
 void CudaStreamManager::sync(int idx) {
     if (this->use_default) {
         return;
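Two details of the new methods are worth noting: torchStream() re-queries c10::cuda::getCurrentCUDAStream() on every call, so it stays correct if the caller switches streams (for example inside a torch.cuda.stream() context), and syncTorch() is a host-blocking cudaStreamSynchronize, not a device-side wait.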
......
@@ -34,8 +34,10 @@ public:
     void setup(int);
     void sync(int=0);
+    void syncTorch();
     void destroy();
 
+    cudaStream_t torchStream();
     cudaStream_t stream(size_t=0);
     cublasHandle_t handle(size_t=0);
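Taken together, the new surface lets kernel code target PyTorch's stream without touching c10 directly; a minimal usage sketch (some_kernel and its arguments are hypothetical):

    void example(CudaStreamManager* smgr, float* buf, size_t n) {
        smgr->syncTorch();  // host-side drain of pending torch work
        // Launch on the same stream PyTorch uses; later torch ops are
        // automatically ordered after this kernel, so no smgr->sync().
        some_kernel<<<CEIL(n, 256), 256, 0, smgr->torchStream()>>>(buf, n);
    }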
......