Commit ad651f03 authored by Rick Ho

fix potential stream synchronization issue

parent 6c68b56b
@@ -6,9 +6,19 @@
 #include <c10/cuda/CUDAGuard.h>

 #include "smart_schedule.h"
+#include "status.h"

 long pipeline_gran = -1;
+int smart_sch_enabled = 0;
+
+int isSmartSchEnabled() {
+    return smart_sch_enabled;
+}
+
+void setSmartSchEnabled(int s) {
+    smart_sch_enabled = s;
+}

 std::vector<torch::Tensor> _smart_sch_forward(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
@@ -24,6 +34,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
         } else {
             pipeline_gran = 4;
         }
+        setSmartSchEnabled(1);
     }

     auto smgr = getCudaStreamManager(input_buf.device().index());
...
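Context for the first hunk: the smart schedule initializes lazily on the first forward pass, and the new smart_sch_enabled flag records that this happened so other translation units can query it through status.h without pulling in smart_schedule.h. A minimal sketch of a consumer of that flag, assuming a hypothetical call site (nothing below is part of this commit):

#include "status.h"

// Hypothetical call site: skip smart-schedule-specific teardown unless
// the pipelined path was actually taken at least once in this process.
void maybe_cleanup_smart_schedule() {
    if (!isSmartSchEnabled()) {
        return;  // smart schedule never ran; nothing to tear down
    }
    // ... release side streams / events here (illustrative only) ...
}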
@@ -76,7 +76,8 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
 template<typename scalar_t>
 void _compute_fn(py::function fn, c10::Device device,
         scalar_t* inp_buf, scalar_t* out_buf,
-        int ei, long step, long offset, long micro_batch_size, long d_model) {
+        int ei, long step, long offset, long micro_batch_size, long d_model,
+        CudaStreamManager* smgr) {
     auto options = torch::TensorOptions()
         .dtype(c10::CppTypeToScalarType<scalar_t>::value)
         .device(device)
@@ -85,7 +86,9 @@ void _compute_fn(py::function fn, c10::Device device,
             {micro_batch_size, d_model}, options);
     auto oup = torch::from_blob(out_buf + offset * d_model,
             {micro_batch_size, d_model}, options);
+    smgr->use_default = true;
     fn(inp, oup, step);
+    smgr->use_default = false;
 }
@@ -156,7 +159,7 @@ void fmoe_cuda_fused_forward_impl(
         _compute_fn(forward_fn, device,
                 global_input_buf, global_output_buf,
-                ei, step, offset, micro_batch_size, d_model);
+                ei, step, offset, micro_batch_size, d_model, smgr);
     }
     auto stream = c10::cuda::getCurrentCUDAStream().stream();
     cudaEventRecord(output_ready[step], stream);
@@ -286,7 +289,7 @@ void fmoe_cuda_fused_backward_impl(
         _compute_fn(backward_fn, device,
                 global_grad_out, global_grad_in,
-                ei, step, offset, micro_batch_size, d_model);
+                ei, step, offset, micro_batch_size, d_model, smgr);
     }
     // TODO: get pytorch's compute stream
 }
...
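The substantive fix is in _compute_fn: fn is a Python callback that launches PyTorch ops on PyTorch's current stream, while the surrounding schedule runs on the manager's side streams. Toggling smgr->use_default around the callback routes any stream the manager hands out back to the current (default) stream for the duration of the call, so the cudaEventRecord(output_ready[step], stream) issued afterwards on that same stream really does fence the expert's work. A minimal sketch of the stream-selection logic this implies; the member names and pool layout below are assumptions, not code from this commit:

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

// Sketch of the routing that a use_default flag suggests.
struct CudaStreamManagerSketch {
    bool use_default = false;      // toggled around the Python callback
    cudaStream_t side_streams[2];  // dedicated comm/compute side streams

    cudaStream_t stream(size_t i) {
        if (use_default) {
            // While the callback runs, every launch site that asks the
            // manager for a stream gets PyTorch's current stream, so
            // torch ops and raw kernels stay ordered on one stream.
            return c10::cuda::getCurrentCUDAStream().stream();
        }
        return side_streams[i % 2];
    }
};

A consumer on another stream then orders itself after the producer with cudaStreamWaitEvent(consumer_stream, output_ready[step], 0), the standard event-based cross-stream handshake.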
#pragma once
#ifndef FASTER_STATUS_H
#define FASTER_STATUS_H
int isSmartSchEnabled();
void setSmartSchEnabled(int);
#endif // FASTER_STATUS_H
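status.h only declares the two accessors, so the flag could also be surfaced to Python, for example so tests can assert that the pipelined path was taken. A sketch of such a binding, assuming pybind11 is available; the module name is invented for illustration:

#include <pybind11/pybind11.h>
#include "status.h"

PYBIND11_MODULE(fmoe_status_example, m) {
    // Expose the C++ status flag to Python: a read query plus a setter.
    m.def("is_smart_sch_enabled", &isSmartSchEnabled,
          "Return nonzero once the smart schedule has initialized.");
    m.def("set_smart_sch_enabled", &setSmartSchEnabled,
          "Force the smart-schedule status flag (testing only).");
}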