Commit 25e9324f authored by Rick Ho

cudaStreamWaitEvent compat

parent c7e6a3db
@@ -48,7 +48,7 @@ void _reduce_grad(
     cudaEvent_t evt_stash;
     cudaEventCreate(&evt_stash);
     cudaEventRecord(evt_stash, torch_stream);
-    cudaStreamWaitEvent(smgr->stream(0), evt_stash, 0);
+    FMOE_SWE(smgr->stream(0), evt_stash);
     cudaEventDestroy(evt_stash);
     auto dtype = getNcclDataType(t.scalar_type());
...
@@ -11,6 +11,11 @@
 #include "../stream_manager.h"
+#if defined(CUDA_VERSION) && (CUDA_VERSION < 110010)
+#define FMOE_SWE(__s__,__e__) cudaStreamWaitEvent(__s__,__e__,0)
+#else
+#define FMOE_SWE(__s__,__e__) cudaStreamWaitEvent(__s__,__e__)
+#endif
 template<typename scalar_t>
 void exchangeWith(
@@ -169,7 +174,7 @@ void fmoe_cuda_fused_forward_impl(
         if (i / num_expert == rank) {
             cudaEventCreate(&evt_get);
             cudaEventRecord(evt_get, torch_stream);
-            cudaStreamWaitEvent(smgr->stream(1), evt_get);
+            FMOE_SWE(smgr->stream(1), evt_get);
             cudaEventDestroy(evt_get);
         }
         NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
@@ -183,7 +188,7 @@ void fmoe_cuda_fused_forward_impl(
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        cudaStreamWaitEvent(torch_stream, input_ready[step], 0);
+        FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -200,7 +205,7 @@ void fmoe_cuda_fused_forward_impl(
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             stash_fn(params[si], si);
-            cudaStreamWaitEvent(torch_stream, evt_shadow[si], 0);
+            FMOE_SWE(torch_stream, evt_shadow[si]);
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
             computeFn(forward_fn, device,
@@ -213,7 +218,7 @@ void fmoe_cuda_fused_forward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
+        FMOE_SWE(smgr->stream(0), output_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -331,7 +336,7 @@ void fmoe_cuda_fused_backward_impl(
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
+        FMOE_SWE(smgr->stream(1), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -349,7 +354,7 @@ void fmoe_cuda_fused_backward_impl(
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             if (i / num_expert == rank) {
-                cudaStreamWaitEvent(torch_stream, evt_reduce[i % num_expert], 0);
+                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
                 set_grad_fn(si);
             }
             ++si;
@@ -358,7 +363,7 @@ void fmoe_cuda_fused_backward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
+        FMOE_SWE(smgr->stream(0), output_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
...
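Below is a minimal, self-contained sketch of the pattern this commit standardizes: record an event on the compute (torch) stream, make a second stream wait on it through the FMOE_SWE shim, then destroy the event. The shim is copied verbatim from the diff above; the kernels, stream names, and main() harness are illustrative placeholders, not FastMoE code.

#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>

// Compatibility shim copied from the diff above: spell out the flags argument
// on toolkits where the defaulted two-argument form may not be available.
#if defined(CUDA_VERSION) && (CUDA_VERSION < 110010)
#define FMOE_SWE(__s__,__e__) cudaStreamWaitEvent(__s__,__e__,0)
#else
#define FMOE_SWE(__s__,__e__) cudaStreamWaitEvent(__s__,__e__)
#endif

// Placeholder kernels standing in for the expert compute / communication work.
__global__ void producer(float* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] = 1.0f;
}

__global__ void consumer(float* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] += 1.0f;
}

int main() {
    const int n = 1 << 20;
    float* buf;
    cudaMalloc(&buf, n * sizeof(float));

    cudaStream_t compute_stream, comm_stream;
    cudaStreamCreate(&compute_stream);
    cudaStreamCreate(&comm_stream);

    // Work enqueued on the compute stream (analogous to torch_stream).
    producer<<<(n + 255) / 256, 256, 0, compute_stream>>>(buf, n);

    // Record / wait / destroy, exactly as in the hunks above: comm_stream will
    // not run past this point until the producer kernel has finished.
    cudaEvent_t evt;
    cudaEventCreate(&evt);
    cudaEventRecord(evt, compute_stream);
    FMOE_SWE(comm_stream, evt);
    cudaEventDestroy(evt);

    // Ordered after producer even though it runs on a different stream.
    consumer<<<(n + 255) / 256, 256, 0, comm_stream>>>(buf, n);

    cudaStreamSynchronize(comm_stream);
    cudaFree(buf);
    cudaStreamDestroy(compute_stream);
    cudaStreamDestroy(comm_stream);
    printf("done\n");
    return 0;
}

Either branch of the macro expands to a cudaStreamWaitEvent call that exists in the runtime API; the macro only decides whether the flags argument is written explicitly, which is what lets the same call sites build across toolkit versions.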