Commit 771dc62d authored by Rick Ho

forward code in smart schedule

parent ad651f03
......@@ -25,8 +25,12 @@ std::vector<torch::Tensor> _smart_sch_forward(
torch::Tensor global_expert_count,
torch::Tensor stored_models,
long global_batch_size,
long expert_size,
long n_workers,
py::function forward_fn) {
py::function forward_fn,
py::function get_param_fn,
py::function stash_fn,
py::function pop_fn) {
if (pipeline_gran == -1) {
char* p = getenv("FMOE_FASTER_GROUP_SIZE");
if (p) {
......@@ -50,11 +54,26 @@ std::vector<torch::Tensor> _smart_sch_forward(
auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
std::vector<torch::Tensor> params;
auto stored_models_ = stored_models.data_ptr<bool>();
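// For every expert marked for shadowing, allocate a flat buffer of
// `expert_size` elements; only the owning rank fills it via get_param_fn,
// and the NCCL broadcast inside the fused kernel distributes it to the rest.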
for (long i = 0; i < num_expert * n_workers; ++i) {
if (stored_models_[i]) {
torch::Tensor t = input_buf.new_empty({expert_size});
if (i / num_expert == rank) {
get_param_fn(t);
}
params.push_back(t);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"fmoe_cuda_smart_sch_forward", ([&] {
fmoe_cuda_fused_forward_impl(
forward_fn,
stash_fn,
pop_fn,
input_buf.device(),
params,
input_buf.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
......@@ -64,7 +83,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
stored_models.data_ptr<bool>(),
d_model, num_expert, rank, n_workers,
d_model, num_expert, rank, n_workers, expert_size,
pipeline_gran, smgr);
}));
return {output_buf, global_input_buf};
......@@ -77,8 +96,13 @@ torch::Tensor _smart_sch_backward(
torch::Tensor stored_models,
long buf_batch_size,
long global_batch_size,
long expert_size,
long n_workers,
py::function backward_fn) {
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
py::function collect_fn,
py::function set_grad_fn) {
const auto num_expert = local_expert_count.size(0) / n_workers;
auto smgr = getCudaStreamManager(grad_out.device().index());
int rank;
......
......@@ -39,6 +39,7 @@ void _exchange_with(
int gidx_recv = ei * world_size + rank_recv; \
int idx_self = ei + rank * num_expert;
void _compute_ptrs(long num_expert, long rank, long world_size,
const long* local_expert_count,
const long* global_expert_count,
......@@ -73,10 +74,11 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
}
}
template<typename scalar_t>
void _compute_fn(py::function fn, c10::Device device,
scalar_t* inp_buf, scalar_t* out_buf,
int ei, long step, long offset, long micro_batch_size, long d_model,
long idx, long offset, long micro_batch_size, long d_model,
CudaStreamManager* smgr) {
auto options = torch::TensorOptions()
.dtype(c10::CppTypeToScalarType<scalar_t>::value)
......@@ -87,7 +89,7 @@ void _compute_fn(py::function fn, c10::Device device,
auto oup = torch::from_blob(out_buf + offset * d_model,
{micro_batch_size, d_model}, options);
smgr->use_default = true;
fn(inp, oup, step);
fn(inp, oup, idx);
smgr->use_default = false;
}
......@@ -95,9 +97,12 @@ void _compute_fn(py::function fn, c10::Device device,
template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
py::function forward_fn,
py::function stash_fn,
py::function pop_fn,
c10::Device device,
std::vector<torch::Tensor> params,
const scalar_t* input_buf,
scalar_t* input_buf,
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
......@@ -107,8 +112,9 @@ void fmoe_cuda_fused_forward_impl(
const bool* stored_models,
long d_model,
long num_expert, long rank, long world_size,
long num_expert, long rank, long world_size, long expert_size,
long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
......@@ -130,8 +136,9 @@ void fmoe_cuda_fused_forward_impl(
cudaEventCreate(output_ready + i);
}
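// The schedule splits the global batch into n_groups pipeline groups:
// S_i exchanges tokens over NCCL, C_i runs the experts on them, and
// R_i returns the results, so communication and computation overlap.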
// S_0 ... S_n
for (long step = 0; step < n_groups; ++step) {
for (int ei = 0; ei < num_expert; ++ei) {
for (long ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
......@@ -149,8 +156,30 @@ void fmoe_cuda_fused_forward_impl(
cudaEventRecord(input_ready[step], smgr->stream(0));
}
// Broadcast shadowed experts
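// Each shadowed expert's flat parameter tensor is broadcast from its owning
// rank to all workers with ncclBcast; evt_shadow[si] records on stream(0)
// when the broadcast has been issued so the compute phase can wait on it.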
cudaEvent_t evt_get, *evt_shadow;
if (params.size() > 0) {
evt_shadow = new cudaEvent_t[params.size()];
}
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
if (i / num_expert == rank) {
cudaEventCreate(&evt_get);
cudaEventRecord(evt_get, torch_stream);
cudaStreamWaitEvent(smgr->stream(1), evt_get);
}
NCCL_SAFE_CALL(ncclBcast(params[si].data_ptr(),
expert_size * sizeof(scalar_t), ncclChar,
i / num_expert, smgr->ncclcomm, smgr->stream(0)));
cudaEventCreate(evt_shadow + si);
cudaEventRecord(evt_shadow[si], smgr->stream(0));
++si;
}
}
// C_0 ... C_n
for (long step = 0; step < n_groups; ++step) {
cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
cudaStreamWaitEvent(torch_stream, input_ready[step], 0);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
......@@ -159,12 +188,27 @@ void fmoe_cuda_fused_forward_impl(
_compute_fn(forward_fn, device,
global_input_buf, global_output_buf,
ei, step, offset, micro_batch_size, d_model, smgr);
step, offset, micro_batch_size, d_model, smgr);
}
cudaEventRecord(output_ready[step], torch_stream);
}
// Compute over shadowed experts
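// stash_fn swaps the broadcast parameters into the local expert module,
// the same forward_fn then runs on the locally kept tokens (index
// n_groups + si distinguishes these calls), and pop_fn below restores the
// original weights once all shadowed experts are done.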
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
stash_fn(params[si], si);
cudaStreamWaitEvent(torch_stream, evt_shadow[si], 0);
long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i];
_compute_fn(forward_fn, device,
input_buf, output_buf,
n_groups + si, offset, micro_batch_size, d_model, smgr);
++si;
}
auto stream = c10::cuda::getCurrentCUDAStream().stream();
cudaEventRecord(output_ready[step], stream);
}
pop_fn();
// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
for (int ei = 0; ei < num_expert; ++ei) {
......@@ -184,31 +228,6 @@ void fmoe_cuda_fused_forward_impl(
}
}
/* TODO: Shadowing support
int offset = global_ptr[world_size * num_expert];
for (int j = 0; j < world_size; j++) {
for (int i = 0; i < num_expert; i++) {
int idx = j * num_expert + i;
if (!stored_models[idx])
continue;
weight1 = params[j][0][0].data_ptr<scalar_t>();
weight2 = params[j][0][last].data_ptr<scalar_t>();
auto stream = 2 + (idx % (SMGR_N_STREAMS- 2));
_compute_mlp_forward(
input_buf + local_ptr[idx] * d_model, weight1, weight2,
middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf + local_ptr[idx] * d_model,
i,
0, local_expert_count[idx],
d_model, d_hidden,
smgr->stream(stream), smgr->handle(stream));
}
}*/
delete [] local_ptr;
delete [] global_ptr;
delete [] local_global_ptr;
......@@ -217,8 +236,14 @@ void fmoe_cuda_fused_forward_impl(
cudaEventDestroy(input_ready[i]);
cudaEventDestroy(output_ready[i]);
}
for (unsigned i = 0; i < params.size(); ++i) {
cudaEventDestroy(evt_shadow[i]);
}
delete [] input_ready;
delete [] output_ready;
if (params.size()) {
delete [] evt_shadow;
}
}
......@@ -238,6 +263,7 @@ void fmoe_cuda_fused_backward_impl(
long d_model,
long num_expert, long rank, long world_size,
long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
......@@ -289,9 +315,9 @@ void fmoe_cuda_fused_backward_impl(
_compute_fn(backward_fn, device,
global_grad_out, global_grad_in,
ei, step, offset, micro_batch_size, d_model, smgr);
step, offset, micro_batch_size, d_model, smgr);
}
// TODO: get pytorch's compute stream
cudaEventRecord(output_ready[step], torch_stream);
}
for (long step = 0; step < n_groups; ++step) {
......
......@@ -63,8 +63,13 @@ std::vector<torch::Tensor> _smart_sch_forward(
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor stored_models,
long global_batch_size, long n_workers,
py::function forward_fn);
long global_batch_size,
long expert_size,
long n_workers,
py::function forward_fn,
py::function get_param_fn,
py::function stash_fn,
py::function pop_fn);
torch::Tensor _smart_sch_backward(
torch::Tensor grad_out,
torch::Tensor local_expert_count,
......@@ -72,8 +77,11 @@ torch::Tensor _smart_sch_backward(
torch::Tensor stored_models,
long buf_batch_size,
long global_batch_size,
long expert_size,
long n_workers,
py::function backward_fn);
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
py::function collect_fn,
py::function set_grad_fn);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
......
......@@ -3,21 +3,34 @@
#include <cassert>
#include <thread>
#include <iostream>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include "fastermoe/status.h"
#include "stream_manager.h"
#define SMGR_N_STREAMS 16
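// When use_default is set, hand back PyTorch's current stream and cuBLAS
// handle instead of the manager's private ones, so expert computation
// launched from Python stays on the stream PyTorch already synchronizes.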
cudaStream_t CudaStreamManager::stream(size_t idx) {
if (this->use_default) {
return c10::cuda::getCurrentCUDAStream().stream();
}
return this->streams[idx % SMGR_N_STREAMS];
}
cublasHandle_t CudaStreamManager::handle(size_t idx) {
if (this->use_default) {
return at::cuda::getCurrentCUDABlasHandle();
}
return this->handles[idx % SMGR_N_STREAMS];
}
void CudaStreamManager::sync(int idx) {
if (this->use_default) {
return;
}
for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
cudaStreamSynchronize(streams[i]);
}
......
......@@ -21,13 +21,14 @@ public:
int device;
cublasHandle_t* handles;
cudaStream_t* streams;
bool use_default;
#ifdef FMOE_USE_NCCL
char ncclgood;
ncclComm_t ncclcomm;
#endif
public:
CudaStreamManager(int device_): device(device_) {
CudaStreamManager(int device_): device(device_), use_default(false) {
this->setup(device);
}
......
import torch
def get_expert_param_size(e):
return sum(map(lambda x: x.numel(), e.parameters()))
def get_expert_params(e, out):
offset = 0
for n, p in e.named_parameters():
seg = out[offset:offset + p.numel()]
offset += p.numel()
seg.copy_(p.data.flatten())
def stash_expert_params(e, params):
if not hasattr(e, 'expert_param_stash'):
setattr(e, 'expert_param_stash', dict())
offset = 0
for n, p in e.named_parameters():
if n not in e.expert_param_stash:
e.expert_param_stash[n] = p.data.clone()
with torch.no_grad():
seg = params[offset:offset + p.numel()]
offset += p.numel()
p.copy_(seg.reshape(p.shape))
def pop_expert_params(e):
for n, p in e.named_parameters():
with torch.no_grad():
p.copy_(e.expert_param_stash[n])
e.expert_param_stash.clear()
def collect_expert_grads(e, grads):
offset = 0
for _, p in e.named_parameters():
seg = grads[offset:offset + p.numel()]
offset += p.numel()
if p.grad is not None:
seg.copy_(p.grad.flatten())
p.grad = None
else:
seg.zero_()
def set_grads(e, grads):
offset = 0
for n, p in e.named_parameters():
seg = grads[offset:offset + p.numel()]
offset += p.numel()
if p.grad is None:
p.grad = seg.clone().reshape(p.shape)
else:
p.grad += seg.reshape(p.shape)
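A rough usage sketch of these helpers (not part of the commit; the toy nn.Sequential expert and sizes are purely illustrative): they pack an expert's parameters into one flat 1-D tensor, temporarily overwrite them with a received copy, and restore the originals afterwards.
import torch
import torch.nn as nn
import expert_utils
# Toy expert, for illustration only.
expert = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
# Flat buffer large enough to hold every parameter of the expert.
flat = torch.empty(expert_utils.get_expert_param_size(expert))
expert_utils.get_expert_params(expert, flat)  # owner packs its weights
# Pretend `flat` arrived from another worker: swap it in, keeping a stash.
expert_utils.stash_expert_params(expert, flat)
# ... run the shadowed expert's forward pass here ...
expert_utils.pop_expert_params(expert)  # restore the original weights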
......@@ -7,6 +7,7 @@ from torch.autograd.function import Function
from fmoe.functions import prepare_forward, ensure_comm
from fmoe.functions import _local_scatter, _local_gather
import fmoe_cuda as fmoe_native
import expert_utils
class MoEForward(Function):
......@@ -14,6 +15,7 @@ class MoEForward(Function):
def forward(
ctx,
expert_fn,
experts,
inp, # models,
pos_s, pos_g,
local_expert_count, global_expert_count,
......@@ -25,8 +27,8 @@ class MoEForward(Function):
# TODO: leave this for future work of expert shadowing
# model_params = [[tuple(m.parameters()) for m in node] for node in models]
ctx.gibs = [None] * world_size
ctx.gobs = [None] * world_size
ctx.gibs = [None] * (world_size * 2)
ctx.gobs = [None] * (world_size * 2)
def _expert_forward(x, y, idx):
x = x.data
with torch.enable_grad():
......@@ -36,11 +38,23 @@ class MoEForward(Function):
ctx.gobs[idx] = y0
y.copy_(y0)
ctx.experts = experts
if stored_models.any():
ctx.expert_size = expert_utils.get_expert_param_size(experts)
else:
ctx.expert_size = 0
get_param_fn = lambda out: expert_utils.get_expert_params(experts, out)
pop_fn = lambda: expert_utils.pop_expert_params(experts)
ctx.shadows = [None] * world_size
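# stash_fn is called from the native code once a shadowed expert's parameters
# have been received: it swaps them into the local expert module (keeping a
# stash of the originals) and remembers them for the backward pass.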
def stash_fn(params, idx):
expert_utils.stash_expert_params(experts, params)
ctx.shadows[idx] = params
local_output_buf, gib = fmoe_native.smart_sch_forward(
local_input_buf,
local_expert_count, global_expert_count,
stored_models, fwd_batch_size,
world_size, _expert_forward)
stored_models, fwd_batch_size, ctx.expert_size,
world_size, _expert_forward, get_param_fn, stash_fn, pop_fn)
out = _local_gather(local_output_buf, pos_g, out_batch_size,
maybe_overlap=False)
......@@ -65,19 +79,27 @@ class MoEForward(Function):
x = ctx.gibs[idx]
grad_x.copy_(x.grad)
experts = ctx.experts
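# Helpers handed to the native backward: stash_fn/pop_fn swap the shadowed
# parameters in and out again, collect_fn drains the expert's accumulated
# gradients into a flat buffer, and set_grad_fn adds a flat gradient buffer
# back onto the parameters (see expert_utils).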
def stash_fn(idx):
expert_utils.stash_expert_params(experts, ctx.shadows[idx])
pop_fn = lambda: expert_utils.pop_expert_params(experts)
collect_fn = lambda g: expert_utils.collect_expert_grads(experts, g)
set_grad_fn = lambda g: expert_utils.set_grads(experts, g)
grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
grad_in_buf = fmoe_native.smart_sch_backward(
grad_out_buf,
local_expert_count, global_expert_count,
stored_models,
pos_s.shape[0], fwd_batch_size,
world_size, _expert_backward)
pos_s.shape[0], fwd_batch_size, ctx.expert_size,
world_size, _expert_backward,
stash_fn, pop_fn, collect_fn, set_grad_fn)
grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
return (None, grad_in, None, None, None, None, None, None, None, None)
return (None, None, grad_in, None, None, None, None, None, None, None, None)
def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None):
# TODO: Using multiple tensors as input is to be supported.
assert(isinstance(inp, torch.Tensor))
# TODO: Support many experts on each process
......@@ -98,7 +120,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
topk = gate.shape[1]
out_batch_size = inp.shape[0] * topk
return MoEForward.apply(expert_fn, inp,
return MoEForward.apply(expert_fn, experts, inp,
torch.div(pos, topk, rounding_mode='floor'), pos,
local_expert_count, global_expert_count, stored_models,
fwd_batch_size, out_batch_size, world_size)
......@@ -21,7 +21,7 @@ def mark_module_parallel_comm(module, comm):
setattr(p, "dp_comm", comm)
def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size, **kwargs):
r"""
A private function that performs the following steps to complete the MoE
computation.
......@@ -227,7 +227,9 @@ class FMoE(nn.Module):
gate_top_k_idx = gate_top_k_idx[mask == 0, :]
fwd = _fmoe_general_global_forward(
moe_inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
moe_inp, gate_top_k_idx, self.expert_fn,
self.num_expert, self.world_size,
experts=self.experts
)
# recover deleted tensors
......