Commit ad651f03 authored by Rick Ho's avatar Rick Ho
Browse files

fix potential stream synchronization issue

parent 6c68b56b
......@@ -6,9 +6,19 @@
#include <c10/cuda/CUDAGuard.h>
#include "smart_schedule.h"
#include "status.h"
long pipeline_gran = -1;
int smart_sch_enabled = 0;
int isSmartSchEnabled() {
return smart_sch_enabled;
}
void setSmartSchEnabled(int s) {
smart_sch_enabled = s;
}
std::vector<torch::Tensor> _smart_sch_forward(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
......@@ -24,6 +34,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
} else {
pipeline_gran = 4;
}
setSmartSchEnabled(1);
}
auto smgr = getCudaStreamManager(input_buf.device().index());
......
......@@ -76,7 +76,8 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
template<typename scalar_t>
void _compute_fn(py::function fn, c10::Device device,
scalar_t* inp_buf, scalar_t* out_buf,
int ei, long step, long offset, long micro_batch_size, long d_model) {
int ei, long step, long offset, long micro_batch_size, long d_model,
CudaStreamManager* smgr) {
auto options = torch::TensorOptions()
.dtype(c10::CppTypeToScalarType<scalar_t>::value)
.device(device)
......@@ -85,7 +86,9 @@ void _compute_fn(py::function fn, c10::Device device,
{micro_batch_size, d_model}, options);
auto oup = torch::from_blob(out_buf + offset * d_model,
{micro_batch_size, d_model}, options);
smgr->use_default = true;
fn(inp, oup, step);
smgr->use_default = false;
}
......@@ -156,7 +159,7 @@ void fmoe_cuda_fused_forward_impl(
_compute_fn(forward_fn, device,
global_input_buf, global_output_buf,
ei, step, offset, micro_batch_size, d_model);
ei, step, offset, micro_batch_size, d_model, smgr);
}
auto stream = c10::cuda::getCurrentCUDAStream().stream();
cudaEventRecord(output_ready[step], stream);
......@@ -286,7 +289,7 @@ void fmoe_cuda_fused_backward_impl(
_compute_fn(backward_fn, device,
global_grad_out, global_grad_in,
ei, step, offset, micro_batch_size, d_model);
ei, step, offset, micro_batch_size, d_model, smgr);
}
// TODO: get pytorch's compute stream
}
......
#pragma once
#ifndef FASTER_STATUS_H
#define FASTER_STATUS_H
int isSmartSchEnabled();
void setSmartSchEnabled(int);
#endif // FASTER_STATUS_H
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment