Commit 14878015 authored by zms1999
Browse files

support n_expert > 1 for FasterMoE smart scheduling and expert shadowing

parent 698a12ae
...@@ -104,7 +104,7 @@ std::vector<torch::Tensor> _smart_sch_forward( ...@@ -104,7 +104,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
if (stored_models_[i]) { if (stored_models_[i]) {
torch::Tensor t = input_buf.new_empty({expert_size}); torch::Tensor t = input_buf.new_empty({expert_size});
if (i / num_expert == rank) { if (i / num_expert == rank) {
get_param_fn(t); get_param_fn(t, i % num_expert);
} }
params.push_back(t); params.push_back(t);
} }
......
...@@ -83,7 +83,7 @@ void computePtrs(long num_expert, long rank, long world_size, ...@@ -83,7 +83,7 @@ void computePtrs(long num_expert, long rank, long world_size,
template<typename scalar_t> template<typename scalar_t>
void computeFn(py::function fn, c10::Device device, void computeFn(py::function fn, c10::Device device,
scalar_t* inp_buf, scalar_t* out_buf, scalar_t* inp_buf, scalar_t* out_buf,
long idx, long offset, long micro_batch_size, long d_model, long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model,
CudaStreamManager* smgr) { CudaStreamManager* smgr) {
if(micro_batch_size == 0) { if(micro_batch_size == 0) {
return; return;
...@@ -97,7 +97,7 @@ void computeFn(py::function fn, c10::Device device, ...@@ -97,7 +97,7 @@ void computeFn(py::function fn, c10::Device device,
auto oup = torch::from_blob(out_buf + offset * d_model, auto oup = torch::from_blob(out_buf + offset * d_model,
{micro_batch_size, d_model}, options); {micro_batch_size, d_model}, options);
smgr->use_default = true; smgr->use_default = true;
fn(inp, oup, idx); fn(inp, oup, expert_idx, store_idx);
smgr->use_default = false; smgr->use_default = false;
} }
...@@ -174,7 +174,7 @@ void fmoe_cuda_fused_forward_impl( ...@@ -174,7 +174,7 @@ void fmoe_cuda_fused_forward_impl(
if (i / num_expert == rank) { if (i / num_expert == rank) {
cudaEventCreate(&evt_get); cudaEventCreate(&evt_get);
cudaEventRecord(evt_get, torch_stream); cudaEventRecord(evt_get, torch_stream);
FMOE_SWE(smgr->stream(1), evt_get); FMOE_SWE(smgr->stream(0), evt_get);
cudaEventDestroy(evt_get); cudaEventDestroy(evt_get);
} }
NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(), NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
...@@ -196,7 +196,7 @@ void fmoe_cuda_fused_forward_impl( ...@@ -196,7 +196,7 @@ void fmoe_cuda_fused_forward_impl(
(from_base + pipeline_gran)] - offset; (from_base + pipeline_gran)] - offset;
computeFn(forward_fn, device, computeFn(forward_fn, device,
global_input_buf, global_output_buf, global_input_buf, global_output_buf,
step, offset, micro_batch_size, d_model, smgr); (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
} }
cudaEventRecord(output_ready[step], torch_stream); cudaEventRecord(output_ready[step], torch_stream);
} }
...@@ -204,17 +204,17 @@ void fmoe_cuda_fused_forward_impl( ...@@ -204,17 +204,17 @@ void fmoe_cuda_fused_forward_impl(
// Compute over shadowed experts // Compute over shadowed experts
for (long i = 0, si = 0; i < world_size * num_expert; ++i) { for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) { if (stored_models[i]) {
stash_fn(params[si], si);
FMOE_SWE(torch_stream, evt_shadow[si]); FMOE_SWE(torch_stream, evt_shadow[si]);
stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
long offset = local_ptr[i]; long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i]; long micro_batch_size = local_expert_count[i];
computeFn(forward_fn, device, computeFn(forward_fn, device,
input_buf, output_buf, input_buf, output_buf,
n_groups + si, offset, micro_batch_size, d_model, smgr); 0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
++si; ++si;
} }
} }
pop_fn(); pop_fn(0);
// R_0 ... R_n // R_0 ... R_n
for (long step = 0; step < n_groups; ++step) { for (long step = 0; step < n_groups; ++step) {
...@@ -319,13 +319,13 @@ void fmoe_cuda_fused_backward_impl( ...@@ -319,13 +319,13 @@ void fmoe_cuda_fused_backward_impl(
cudaEvent_t *evt_reduce = new cudaEvent_t[num_expert]; cudaEvent_t *evt_reduce = new cudaEvent_t[num_expert];
for (long i = 0, si = 0; i < world_size * num_expert; ++i) { for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) { if (stored_models[i]) {
stash_fn(si); stash_fn(si, 0);
long offset = local_ptr[i]; long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i]; long micro_batch_size = local_expert_count[i];
computeFn(backward_fn, device, computeFn(backward_fn, device,
grad_out, grad_in, grad_out, grad_in,
n_groups + si, offset, micro_batch_size, d_model, smgr); 0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
collect_fn(si, i / num_expert); collect_fn(si, i / num_expert, 0);
if (i / num_expert == rank) { if (i / num_expert == rank) {
cudaEventCreate(evt_reduce + i % num_expert); cudaEventCreate(evt_reduce + i % num_expert);
cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0)); cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
...@@ -333,11 +333,11 @@ void fmoe_cuda_fused_backward_impl( ...@@ -333,11 +333,11 @@ void fmoe_cuda_fused_backward_impl(
++si; ++si;
} }
} }
pop_fn(); pop_fn(0);
// C_0 ... C_n // C_0 ... C_n
for (long step = 0; step < n_groups; ++step) { for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(1), input_ready[step]); FMOE_SWE(torch_stream, input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) { for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step); GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base]; long offset = global_ptr[ei * world_size + from_base];
...@@ -346,7 +346,7 @@ void fmoe_cuda_fused_backward_impl( ...@@ -346,7 +346,7 @@ void fmoe_cuda_fused_backward_impl(
computeFn(backward_fn, device, computeFn(backward_fn, device,
global_grad_out, global_grad_in, global_grad_out, global_grad_in,
step, offset, micro_batch_size, d_model, smgr); (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
} }
cudaEventRecord(output_ready[step], torch_stream); cudaEventRecord(output_ready[step], torch_stream);
} }
...@@ -356,7 +356,7 @@ void fmoe_cuda_fused_backward_impl( ...@@ -356,7 +356,7 @@ void fmoe_cuda_fused_backward_impl(
if (stored_models[i]) { if (stored_models[i]) {
if (i / num_expert == rank) { if (i / num_expert == rank) {
FMOE_SWE(torch_stream, evt_reduce[i % num_expert]); FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
set_grad_fn(si); set_grad_fn(si, i % num_expert);
} }
++si; ++si;
} }
......
import torch import torch
def get_expert_param_size(e): def get_expert_param_size(e, idx):
e = e[idx]
return sum(map(lambda x: x.numel(), e.parameters())) return sum(map(lambda x: x.numel(), e.parameters()))
def get_expert_params(e, out): def get_expert_params(e, out, idx):
e = e[idx]
offset = 0 offset = 0
for n, p in e.named_parameters(): for n, p in e.named_parameters():
seg = out[offset:offset + p.numel()] seg = out[offset:offset + p.numel()]
...@@ -13,20 +15,25 @@ def get_expert_params(e, out): ...@@ -13,20 +15,25 @@ def get_expert_params(e, out):
seg.copy_(p.data.flatten()) seg.copy_(p.data.flatten())
def stash_expert_params(e, params): def stash_expert_params(e, params, idx):
e = e[idx]
if not hasattr(e, 'expert_param_stash'): if not hasattr(e, 'expert_param_stash'):
setattr(e, 'expert_param_stash', dict()) setattr(e, 'expert_param_stash', dict())
setattr(e, 'expert_grad_stash', dict())
offset = 0 offset = 0
for n, p in e.named_parameters(): for n, p in e.named_parameters():
if n not in e.expert_param_stash: if n not in e.expert_param_stash:
e.expert_param_stash[n] = p.data.clone() e.expert_param_stash[n] = p.data.clone()
e.expert_grad_stash[n] = p.grad.clone() if p.grad is not None else None
with torch.no_grad(): with torch.no_grad():
seg = params[offset:offset + p.numel()] seg = params[offset:offset + p.numel()]
offset += p.numel() offset += p.numel()
p.copy_(seg.reshape(p.shape)) p.copy_(seg.reshape(p.shape))
p.grad = None
def pop_expert_params(e): def pop_expert_params(e, idx):
e = e[idx]
if not hasattr(e, 'expert_param_stash'): if not hasattr(e, 'expert_param_stash'):
return return
if not e.expert_param_stash: if not e.expert_param_stash:
...@@ -34,10 +41,14 @@ def pop_expert_params(e): ...@@ -34,10 +41,14 @@ def pop_expert_params(e):
for n, p in e.named_parameters(): for n, p in e.named_parameters():
with torch.no_grad(): with torch.no_grad():
p.copy_(e.expert_param_stash[n]) p.copy_(e.expert_param_stash[n])
if e.expert_grad_stash[n] is not None:
p.grad = e.expert_grad_stash[n].clone()
e.expert_param_stash.clear() e.expert_param_stash.clear()
e.expert_grad_stash.clear()
def collect_expert_grads(e, grads): def collect_expert_grads(e, grads, idx):
e = e[idx]
offset = 0 offset = 0
for _, p in e.named_parameters(): for _, p in e.named_parameters():
seg = grads[offset:offset + p.numel()] seg = grads[offset:offset + p.numel()]
...@@ -49,7 +60,8 @@ def collect_expert_grads(e, grads): ...@@ -49,7 +60,8 @@ def collect_expert_grads(e, grads):
seg.zero_() seg.zero_()
def set_grads(e, grads): def set_grads(e, grads, idx):
e = e[idx]
offset = 0 offset = 0
for n, p in e.named_parameters(): for n, p in e.named_parameters():
seg = grads[offset:offset + p.numel()] seg = grads[offset:offset + p.numel()]
......
...@@ -23,12 +23,13 @@ class MoEForward(Function): ...@@ -23,12 +23,13 @@ class MoEForward(Function):
local_expert_count, global_expert_count, local_expert_count, global_expert_count,
stored_models, stored_models,
fwd_batch_size, out_batch_size, fwd_batch_size, out_batch_size,
num_expert,
world_size): world_size):
local_input_buf = _local_scatter(inp, pos_s) local_input_buf = _local_scatter(inp, pos_s)
ctx.gibs = [None] * (world_size * 2) ctx.gibs = [None] * (world_size * num_expert * 2)
ctx.gobs = [None] * (world_size * 2) ctx.gobs = [None] * (world_size * num_expert * 2)
def _expert_forward(x, y, idx): def _expert_forward(x, y, expert_idx, store_idx):
nothing = lambda a: a nothing = lambda a: a
x = x.data x = x.data
with torch.enable_grad(): with torch.enable_grad():
...@@ -40,22 +41,24 @@ class MoEForward(Function): ...@@ -40,22 +41,24 @@ class MoEForward(Function):
except Exception as e: except Exception as e:
# Ignore the error and fall back for compatibility to older # Ignore the error and fall back for compatibility to older
# versions of PyTorch # versions of PyTorch
y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64)) y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
ctx.gibs[idx] = x ctx.gibs[store_idx] = x
ctx.gobs[idx] = y0 ctx.gobs[store_idx] = y0
y.copy_(y0) y.copy_(y0)
ctx.experts = experts ctx.experts = experts
if stored_models.any(): if stored_models.any():
ctx.expert_size = expert_utils.get_expert_param_size(experts) ctx.expert_size = expert_utils.get_expert_param_size(experts, 0)
for i in range(num_expert):
assert ctx.expert_size == expert_utils.get_expert_param_size(experts, i), "report bug"
else: else:
ctx.expert_size = 0 ctx.expert_size = 0
get_param_fn = lambda out: expert_utils.get_expert_params(experts, out) get_param_fn = lambda out, idx: expert_utils.get_expert_params(experts, out, idx)
pop_fn = lambda: expert_utils.pop_expert_params(experts) pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
ctx.shadows = [None] * world_size ctx.shadows = [None] * world_size * num_expert
def stash_fn(params, idx): def stash_fn(params, store_idx, expert_idx):
expert_utils.stash_expert_params(experts, params) expert_utils.stash_expert_params(experts, params, expert_idx)
ctx.shadows[idx] = params ctx.shadows[store_idx] = params
local_output_buf, gib = fmoe_native.smart_sch_forward( local_output_buf, gib = fmoe_native.smart_sch_forward(
local_input_buf, local_input_buf,
...@@ -71,7 +74,7 @@ class MoEForward(Function): ...@@ -71,7 +74,7 @@ class MoEForward(Function):
variables = (pos_s, pos_g, local_expert_count, global_expert_count, variables = (pos_s, pos_g, local_expert_count, global_expert_count,
stored_models, gib, local_input_buf) stored_models, gib, local_input_buf)
ctx.moe_args = fwd_batch_size, inp.shape[0], world_size ctx.moe_args = fwd_batch_size, inp.shape[0], num_expert, world_size
ctx.save_for_backward(*variables) ctx.save_for_backward(*variables)
return out return out
...@@ -80,23 +83,23 @@ class MoEForward(Function): ...@@ -80,23 +83,23 @@ class MoEForward(Function):
def backward(ctx, grad_out): def backward(ctx, grad_out):
(pos_s, pos_g, local_expert_count, global_expert_count, (pos_s, pos_g, local_expert_count, global_expert_count,
stored_models, _1, _2) = ctx.saved_tensors stored_models, _1, _2) = ctx.saved_tensors
(fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args (fwd_batch_size, inp_batch_size, num_expert, world_size) = ctx.moe_args
def _expert_backward(grad_y, grad_x, idx): def _expert_backward(grad_y, grad_x, expert_idx, store_idx):
y = ctx.gobs[idx] y = ctx.gobs[store_idx]
x = ctx.gibs[idx] x = ctx.gibs[store_idx]
torch.autograd.backward([y], [grad_y]) torch.autograd.backward([y], [grad_y])
grad_x.copy_(x.grad) grad_x.copy_(x.grad)
experts = ctx.experts experts = ctx.experts
def stash_fn(idx): def stash_fn(store_idx, expert_idx):
expert_utils.stash_expert_params(experts, ctx.shadows[idx]) expert_utils.stash_expert_params(experts, ctx.shadows[store_idx], expert_idx)
pop_fn = lambda: expert_utils.pop_expert_params(experts) pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
def collect_fn(idx, root): def collect_fn(store_idx, root, expert_idx):
grad = ctx.shadows[idx] grad = ctx.shadows[store_idx]
expert_utils.collect_expert_grads(experts, grad) expert_utils.collect_expert_grads(experts, grad, expert_idx)
fmoe_native.reduce_grad(grad, root, ctx.expert_size) fmoe_native.reduce_grad(grad, root, ctx.expert_size)
set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx]) set_grad_fn = lambda store_idx, expert_idx: expert_utils.set_grads(experts, ctx.shadows[store_idx], expert_idx)
grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g) grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
grad_in_buf = fmoe_native.smart_sch_backward( grad_in_buf = fmoe_native.smart_sch_backward(
...@@ -108,7 +111,7 @@ class MoEForward(Function): ...@@ -108,7 +111,7 @@ class MoEForward(Function):
_expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn) _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size) grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
return (None, None, grad_in, None, None, None, None, None, None, None, None) return (None, None, grad_in, None, None, None, None, None, None, None, None, None)
policy_fn = None policy_fn = None
...@@ -117,8 +120,6 @@ policy_fn = None ...@@ -117,8 +120,6 @@ policy_fn = None
def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None): def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
# TODO: Using multiple tensors as input is to be supported. # TODO: Using multiple tensors as input is to be supported.
assert(isinstance(inp, torch.Tensor)) assert(isinstance(inp, torch.Tensor))
# TODO: Support many experts on each process
assert(n_expert == 1)
( (
pos, pos,
local_expert_count, local_expert_count,
...@@ -143,4 +144,4 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp ...@@ -143,4 +144,4 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
return MoEForward.apply(expert_fn, experts, inp, return MoEForward.apply(expert_fn, experts, inp,
torch.div(pos, topk, rounding_mode='floor'), pos, torch.div(pos, topk, rounding_mode='floor'), pos,
local_expert_count, global_expert_count, stored_models, local_expert_count, global_expert_count, stored_models,
fwd_batch_size, out_batch_size, world_size) fwd_batch_size, out_batch_size, n_expert, world_size)
...@@ -159,16 +159,24 @@ class FMoE(nn.Module): ...@@ -159,16 +159,24 @@ class FMoE(nn.Module):
if self.experts_fused: if self.experts_fused:
return self.experts(inp, fwd_expert_count) return self.experts(inp, fwd_expert_count)
if isinstance(fwd_expert_count, torch.Tensor): if isinstance(fwd_expert_count, torch.Tensor):
fwd_expert_count = fwd_expert_count.cpu().numpy() fwd_expert_count_cpu = fwd_expert_count.cpu().numpy()
outputs = [] outputs = []
base_idx = 0 base_idx = 0
for i in range(self.num_expert): for i in range(self.num_expert):
batch_size = fwd_expert_count[i] batch_size = fwd_expert_count_cpu[i]
inp_slice = inp[base_idx : base_idx + batch_size] inp_slice = inp[base_idx : base_idx + batch_size]
outputs.append(self.experts[i](inp_slice)) outputs.append(self.experts[i](inp_slice, torch.tensor([fwd_expert_count[i]])))
base_idx += batch_size base_idx += batch_size
return torch.cat(outputs, dim=0) return torch.cat(outputs, dim=0)
def expert_fn_single(self, inp, fwd_expert_count, idx):
    r"""
    Run one expert, selected by `idx`, for smart scheduling.

    `inp` and `fwd_expert_count` are passed unchanged to `self.experts[idx]`
    and whatever that call returns is handed back to the caller. Fused
    experts are rejected by the assertion below.
    """
    assert not self.experts_fused, "should not use fused experts"
    chosen_expert = self.experts[idx]
    return chosen_expert(inp, fwd_expert_count)
def mark_parallel_comm(self, expert_dp_comm="none"): def mark_parallel_comm(self, expert_dp_comm="none"):
r""" r"""
Automatically mark the data parallel comms of the parameters within the Automatically mark the data parallel comms of the parameters within the
...@@ -231,7 +239,7 @@ class FMoE(nn.Module): ...@@ -231,7 +239,7 @@ class FMoE(nn.Module):
gate_top_k_idx = gate_top_k_idx[mask == 0, :] gate_top_k_idx = gate_top_k_idx[mask == 0, :]
fwd = _fmoe_general_global_forward( fwd = _fmoe_general_global_forward(
moe_inp, gate_top_k_idx, self.expert_fn, moe_inp, gate_top_k_idx, self.expert_fn_single if fmoe_faster_schedule else self.expert_fn,
self.num_expert, self.world_size, self.num_expert, self.world_size,
experts=self.experts experts=self.experts
) )
......
...@@ -130,9 +130,20 @@ class MegatronMLP(FMoETransformerMLP): ...@@ -130,9 +130,20 @@ class MegatronMLP(FMoETransformerMLP):
additional numpy rng is used. additional numpy rng is used.
""" """
rng = np.random.default_rng(np.random.randint(2048) + self.rank) rng = np.random.default_rng(np.random.randint(2048) + self.rank)
_megatron_init_method(self.experts.htoh4, rng, self.sigma)
if type(self.experts) is nn.ModuleList:
for expert in self.experts:
_megatron_init_method(expert.htoh4, rng, self.sigma)
else:
_megatron_init_method(self.experts.htoh4, rng, self.sigma)
std = self.sigma / math.sqrt(2.0 * self.num_layers) std = self.sigma / math.sqrt(2.0 * self.num_layers)
_megatron_init_method(self.experts.h4toh, rng, std)
if type(self.experts) is nn.ModuleList:
for expert in self.experts:
_megatron_init_method(expert.h4toh, rng, std)
else:
_megatron_init_method(self.experts.h4toh, rng, std)
def forward(self, inp): def forward(self, inp):
from megatron import mpu from megatron import mpu
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from .layers import FMoE from .layers import FMoE
from .linear import FMoELinear from .linear import FMoELinear
from .fastermoe.config import switch_from_env
class _Expert(nn.Module): class _Expert(nn.Module):
...@@ -47,10 +48,11 @@ class FMoETransformerMLP(FMoE): ...@@ -47,10 +48,11 @@ class FMoETransformerMLP(FMoE):
expert_rank=0, expert_rank=0,
**kwargs **kwargs
): ):
super().__init__(num_expert=num_expert, d_model=d_model, **kwargs) def one_expert(d_model):
self.experts = _Expert( return _Expert(1, d_model, d_hidden, activation, rank=0)
num_expert, d_model, d_hidden, activation, rank=expert_rank
) expert = one_expert
super().__init__(num_expert=num_expert, d_model=d_model, expert=expert, **kwargs)
self.mark_parallel_comm(expert_dp_comm) self.mark_parallel_comm(expert_dp_comm)
def forward(self, inp: torch.Tensor): def forward(self, inp: torch.Tensor):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment