Commit 14878015 authored by zms1999

support n_expert > 1 for FasterMoE smart scheduling and expert shadowing

parent 698a12ae
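Note on the indexing scheme (a reading of the diff, not text from the commit): the scheduler previously identified a micro-batch by a single `idx`; it now carries an `expert_idx` (which local expert module to run) and a `store_idx` (which slot holds the micro-batch's activations for backward). A minimal runnable sketch of the slot arithmetic used throughout the hunks below, with illustrative constants:

```python
# Sketch of the flattened slot layout this commit introduces; the
# constants are illustrative, not from the code.
num_expert = 2   # experts per worker; the scheduler previously asserted 1
n_groups = 4     # pipeline groups (world_size / pipeline_gran in the C++ impl)

def store_idx_regular(step, ei):
    # micro-batch of local expert `ei` at pipeline step `step`
    return step * num_expert + ei

def store_idx_shadow(si):
    # shadowed experts occupy the slots after all regular micro-batches
    return n_groups * num_expert + si

# slots 0 .. n_groups*num_expert-1 are regular; shadow slots follow contiguously
assert store_idx_shadow(0) == store_idx_regular(n_groups - 1, num_expert - 1) + 1
```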
@@ -104,7 +104,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
         if (stored_models_[i]) {
             torch::Tensor t = input_buf.new_empty({expert_size});
             if (i / num_expert == rank) {
-                get_param_fn(t);
+                get_param_fn(t, i % num_expert);
             }
             params.push_back(t);
         }
......
@@ -83,7 +83,7 @@ void computePtrs(long num_expert, long rank, long world_size,
 template<typename scalar_t>
 void computeFn(py::function fn, c10::Device device,
         scalar_t* inp_buf, scalar_t* out_buf,
-        long idx, long offset, long micro_batch_size, long d_model,
+        long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model,
         CudaStreamManager* smgr) {
     if(micro_batch_size == 0) {
         return;
@@ -97,7 +97,7 @@ void computeFn(py::function fn, c10::Device device,
     auto oup = torch::from_blob(out_buf + offset * d_model,
             {micro_batch_size, d_model}, options);
     smgr->use_default = true;
-    fn(inp, oup, idx);
+    fn(inp, oup, expert_idx, store_idx);
     smgr->use_default = false;
 }
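`computeFn` now forwards both indices to the Python callback, so the callback signature grows from `fn(inp, oup, idx)` to `fn(inp, oup, expert_idx, store_idx)`. A toy stand-in for such a callback (the real ones are `_expert_forward` and `_expert_backward` in the Python hunks below; the buffer size and the placeholder "expert" here are made up):

```python
import torch

# gibs/gobs are sized world_size * num_expert * 2 in schedule.py;
# 16 corresponds to world_size = 4, num_expert = 2.
gibs = [None] * 16
gobs = [None] * 16

def example_forward_fn(inp, oup, expert_idx, store_idx):
    y = inp * (expert_idx + 1)                 # placeholder for the selected expert
    gibs[store_idx], gobs[store_idx] = inp, y  # remembered for backward
    oup.copy_(y)

example_forward_fn(torch.ones(2, 3), torch.empty(2, 3), expert_idx=1, store_idx=5)
```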
@@ -174,7 +174,7 @@ void fmoe_cuda_fused_forward_impl(
             if (i / num_expert == rank) {
                 cudaEventCreate(&evt_get);
                 cudaEventRecord(evt_get, torch_stream);
-                FMOE_SWE(smgr->stream(1), evt_get);
+                FMOE_SWE(smgr->stream(0), evt_get);
                 cudaEventDestroy(evt_get);
             }
             NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
@@ -196,7 +196,7 @@ void fmoe_cuda_fused_forward_impl(
                     (from_base + pipeline_gran)] - offset;
             computeFn(forward_fn, device,
                     global_input_buf, global_output_buf,
-                    step, offset, micro_batch_size, d_model, smgr);
+                    (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], torch_stream);
     }
@@ -204,17 +204,17 @@ void fmoe_cuda_fused_forward_impl(
     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
-            stash_fn(params[si], si);
             FMOE_SWE(torch_stream, evt_shadow[si]);
+            stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
             computeFn(forward_fn, device,
                     input_buf, output_buf,
-                    n_groups + si, offset, micro_batch_size, d_model, smgr);
+                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
             ++si;
         }
     }
-    pop_fn();
+    pop_fn(0);
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
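All shadowed experts reuse local expert slot 0: their received parameters are stashed into expert 0, the compute runs with `expert_idx = 0` and a store slot past the regular range, and a single `pop_fn(0)` restores expert 0's own weights afterwards. A runnable mirror of that sequence with stub callbacks (everything here is illustrative, not the real fmoe API):

```python
num_expert, n_groups = 2, 4
stashed = {}

def stash_fn(params, si, expert_idx):
    stashed[si] = (expert_idx, params)      # the real code swaps expert weights here

def compute(expert_idx, store_idx):
    print(f"shadow compute on expert {expert_idx}, stash slot {store_idx}")

def pop_fn(expert_idx):
    print(f"restore weights of expert {expert_idx}")

shadow_params = ["w0", "w1"]                # stand-ins for broadcast parameter blobs
for si, params in enumerate(shadow_params):
    stash_fn(params, si, 0)                 # every shadow reuses local expert 0
    compute(0, n_groups * num_expert + si)  # slots after all regular micro-batches
pop_fn(0)                                   # one restore: only expert 0 was touched
```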
@@ -319,13 +319,13 @@ void fmoe_cuda_fused_backward_impl(
     cudaEvent_t *evt_reduce = new cudaEvent_t[num_expert];
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
-            stash_fn(si);
+            stash_fn(si, 0);
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
             computeFn(backward_fn, device,
                     grad_out, grad_in,
-                    n_groups + si, offset, micro_batch_size, d_model, smgr);
-            collect_fn(si, i / num_expert);
+                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
+            collect_fn(si, i / num_expert, 0);
             if (i / num_expert == rank) {
                 cudaEventCreate(evt_reduce + i % num_expert);
                 cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
@@ -333,11 +333,11 @@
             ++si;
         }
     }
-    pop_fn();
+    pop_fn(0);
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(1), input_ready[step]);
+        FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -346,7 +346,7 @@
             computeFn(backward_fn, device,
                     global_grad_out, global_grad_in,
-                    step, offset, micro_batch_size, d_model, smgr);
+                    (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], torch_stream);
     }
@@ -356,7 +356,7 @@
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
-                set_grad_fn(si);
+                set_grad_fn(si, i % num_expert);
             }
             ++si;
         }
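On the backward side, the gradients of each shadow are collected from expert 0's copy, reduced to the owning rank `i / num_expert`, and only the owner installs them on its real local expert `i % num_expert`. A small runnable sketch of that bookkeeping with made-up inputs:

```python
# Illustrative index bookkeeping for the backward shadow loop above;
# `shadowed` and `rank` are invented inputs, not real state.
num_expert, rank = 2, 1
shadowed = [3, 4]                            # hypothetical global ids of shadowed experts

for si, i in enumerate(shadowed):
    owner = i // num_expert                  # `i / num_expert` in the C++ code
    print(f"collect_fn({si}, {owner}, 0)")   # reduce shadow si's grads to its owner
    if owner == rank:
        print(f"set_grad_fn({si}, {i % num_expert})")  # owner applies to local expert
```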
......
 import torch
 
-def get_expert_param_size(e):
+def get_expert_param_size(e, idx):
+    e = e[idx]
     return sum(map(lambda x: x.numel(), e.parameters()))
 
-def get_expert_params(e, out):
+def get_expert_params(e, out, idx):
+    e = e[idx]
     offset = 0
     for n, p in e.named_parameters():
         seg = out[offset:offset + p.numel()]
@@ -13,20 +15,25 @@ def get_expert_params(e, out):
         seg.copy_(p.data.flatten())
 
-def stash_expert_params(e, params):
+def stash_expert_params(e, params, idx):
+    e = e[idx]
     if not hasattr(e, 'expert_param_stash'):
         setattr(e, 'expert_param_stash', dict())
+        setattr(e, 'expert_grad_stash', dict())
     offset = 0
     for n, p in e.named_parameters():
         if n not in e.expert_param_stash:
             e.expert_param_stash[n] = p.data.clone()
+            e.expert_grad_stash[n] = p.grad.clone() if p.grad is not None else None
         with torch.no_grad():
             seg = params[offset:offset + p.numel()]
             offset += p.numel()
             p.copy_(seg.reshape(p.shape))
+        p.grad = None
 
-def pop_expert_params(e):
+def pop_expert_params(e, idx):
+    e = e[idx]
     if not hasattr(e, 'expert_param_stash'):
         return
     if not e.expert_param_stash:
@@ -34,10 +41,14 @@ def pop_expert_params(e):
     for n, p in e.named_parameters():
         with torch.no_grad():
             p.copy_(e.expert_param_stash[n])
+        if e.expert_grad_stash[n] is not None:
+            p.grad = e.expert_grad_stash[n].clone()
     e.expert_param_stash.clear()
+    e.expert_grad_stash.clear()
 
-def collect_expert_grads(e, grads):
+def collect_expert_grads(e, grads, idx):
+    e = e[idx]
     offset = 0
     for _, p in e.named_parameters():
         seg = grads[offset:offset + p.numel()]
@@ -49,7 +60,8 @@ def collect_expert_grads(e, grads):
         seg.zero_()
 
-def set_grads(e, grads):
+def set_grads(e, grads, idx):
+    e = e[idx]
     offset = 0
     for n, p in e.named_parameters():
         seg = grads[offset:offset + p.numel()]
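Every helper in this file now takes the expert container plus an index and operates on `e[idx]`. A hedged usage sketch, assuming an fmoe checkout that contains this commit and exposes these helpers as `fmoe.fastermoe.expert_utils`:

```python
import torch
import torch.nn as nn
from fmoe.fastermoe import expert_utils   # assumed module path for this commit

# The container is indexed first: every helper operates on experts[idx].
experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
size = expert_utils.get_expert_param_size(experts, 0)   # 4*4 weights + 4 biases = 20

incoming = torch.randn(size)                 # e.g. a shadowed expert's flat params
expert_utils.stash_expert_params(experts, incoming, 0)  # expert 0 now runs the shadow
out = experts[0](torch.randn(1, 4))
expert_utils.pop_expert_params(experts, 0)   # original weights (and grads) restored
```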
......
@@ -23,12 +23,13 @@ class MoEForward(Function):
             local_expert_count, global_expert_count,
             stored_models,
             fwd_batch_size, out_batch_size,
+            num_expert,
             world_size):
         local_input_buf = _local_scatter(inp, pos_s)
-        ctx.gibs = [None] * (world_size * 2)
-        ctx.gobs = [None] * (world_size * 2)
-        def _expert_forward(x, y, idx):
+        ctx.gibs = [None] * (world_size * num_expert * 2)
+        ctx.gobs = [None] * (world_size * num_expert * 2)
+        def _expert_forward(x, y, expert_idx, store_idx):
             nothing = lambda a: a
             x = x.data
             with torch.enable_grad():
@@ -40,22 +41,24 @@
             except Exception as e:
                 # Ignore the error and fall back for compatibility to older
                 # versions of PyTorch
-                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
-            ctx.gibs[idx] = x
-            ctx.gobs[idx] = y0
+                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
+            ctx.gibs[store_idx] = x
+            ctx.gobs[store_idx] = y0
             y.copy_(y0)
         ctx.experts = experts
         if stored_models.any():
-            ctx.expert_size = expert_utils.get_expert_param_size(experts)
+            ctx.expert_size = expert_utils.get_expert_param_size(experts, 0)
+            for i in range(num_expert):
+                assert ctx.expert_size == expert_utils.get_expert_param_size(experts, i), "report bug"
         else:
             ctx.expert_size = 0
-        get_param_fn = lambda out: expert_utils.get_expert_params(experts, out)
-        pop_fn = lambda: expert_utils.pop_expert_params(experts)
-        ctx.shadows = [None] * world_size
-        def stash_fn(params, idx):
-            expert_utils.stash_expert_params(experts, params)
-            ctx.shadows[idx] = params
+        get_param_fn = lambda out, idx: expert_utils.get_expert_params(experts, out, idx)
+        pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
+        ctx.shadows = [None] * world_size * num_expert
+        def stash_fn(params, store_idx, expert_idx):
+            expert_utils.stash_expert_params(experts, params, expert_idx)
+            ctx.shadows[store_idx] = params
         local_output_buf, gib = fmoe_native.smart_sch_forward(
                 local_input_buf,
@@ -71,7 +74,7 @@
         variables = (pos_s, pos_g, local_expert_count, global_expert_count,
                 stored_models, gib, local_input_buf)
-        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
+        ctx.moe_args = fwd_batch_size, inp.shape[0], num_expert, world_size
         ctx.save_for_backward(*variables)
         return out
@@ -80,23 +83,23 @@
     def backward(ctx, grad_out):
         (pos_s, pos_g, local_expert_count, global_expert_count,
                 stored_models, _1, _2) = ctx.saved_tensors
-        (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
+        (fwd_batch_size, inp_batch_size, num_expert, world_size) = ctx.moe_args
-        def _expert_backward(grad_y, grad_x, idx):
-            y = ctx.gobs[idx]
-            x = ctx.gibs[idx]
+        def _expert_backward(grad_y, grad_x, expert_idx, store_idx):
+            y = ctx.gobs[store_idx]
+            x = ctx.gibs[store_idx]
             torch.autograd.backward([y], [grad_y])
             grad_x.copy_(x.grad)
         experts = ctx.experts
-        def stash_fn(idx):
-            expert_utils.stash_expert_params(experts, ctx.shadows[idx])
-        pop_fn = lambda: expert_utils.pop_expert_params(experts)
-        def collect_fn(idx, root):
-            grad = ctx.shadows[idx]
-            expert_utils.collect_expert_grads(experts, grad)
+        def stash_fn(store_idx, expert_idx):
+            expert_utils.stash_expert_params(experts, ctx.shadows[store_idx], expert_idx)
+        pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
+        def collect_fn(store_idx, root, expert_idx):
+            grad = ctx.shadows[store_idx]
+            expert_utils.collect_expert_grads(experts, grad, expert_idx)
             fmoe_native.reduce_grad(grad, root, ctx.expert_size)
-        set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx])
+        set_grad_fn = lambda store_idx, expert_idx: expert_utils.set_grads(experts, ctx.shadows[store_idx], expert_idx)
         grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
         grad_in_buf = fmoe_native.smart_sch_backward(
@@ -108,7 +111,7 @@
                 _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
         grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
-        return (None, None, grad_in, None, None, None, None, None, None, None, None)
+        return (None, None, grad_in, None, None, None, None, None, None, None, None, None)
 
 policy_fn = None
@@ -117,8 +120,6 @@ policy_fn = None
 def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
     # TODO: Using multiple tensors as input is to be supported.
     assert(isinstance(inp, torch.Tensor))
-    # TODO: Support many experts on each process
-    assert(n_expert == 1)
     (
         pos,
         local_expert_count,
@@ -143,4 +144,4 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
     return MoEForward.apply(expert_fn, experts, inp,
             torch.div(pos, topk, rounding_mode='floor'), pos,
             local_expert_count, global_expert_count, stored_models,
-            fwd_batch_size, out_batch_size, world_size)
+            fwd_batch_size, out_batch_size, n_expert, world_size)
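`MoEForward.apply` now also receives `n_expert`, and the activation stashes grow from `world_size * 2` to `world_size * num_expert * 2` slots. One plausible reading of that bound (my inference, not stated in the commit): regular micro-batches occupy at most `n_groups * num_expert` slots and shadows at most `world_size * num_expert`, and since `n_groups <= world_size` both ranges fit:

```python
# ctx.gibs / ctx.gobs need one slot per regular micro-batch plus one per
# possible shadow; world_size * num_expert * 2 covers both ranges.
world_size, num_expert = 4, 2
n_groups = 4                                  # at most world_size

regular = n_groups * num_expert               # store_idx = step * num_expert + ei
shadows = world_size * num_expert             # at most one shadow per global expert
assert regular + shadows <= world_size * num_expert * 2
```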
@@ -159,16 +159,24 @@ class FMoE(nn.Module):
         if self.experts_fused:
             return self.experts(inp, fwd_expert_count)
         if isinstance(fwd_expert_count, torch.Tensor):
-            fwd_expert_count = fwd_expert_count.cpu().numpy()
+            fwd_expert_count_cpu = fwd_expert_count.cpu().numpy()
         outputs = []
         base_idx = 0
         for i in range(self.num_expert):
-            batch_size = fwd_expert_count[i]
+            batch_size = fwd_expert_count_cpu[i]
             inp_slice = inp[base_idx : base_idx + batch_size]
-            outputs.append(self.experts[i](inp_slice))
+            outputs.append(self.experts[i](inp_slice, torch.tensor([fwd_expert_count[i]])))
             base_idx += batch_size
         return torch.cat(outputs, dim=0)
 
+    def expert_fn_single(self, inp, fwd_expert_count, idx):
+        r"""
+        forward single expert for smart scheduling.
+        """
+        assert not self.experts_fused, "should not use fused experts"
+        output = self.experts[idx](inp, fwd_expert_count)
+        return output
+
     def mark_parallel_comm(self, expert_dp_comm="none"):
         r"""
         Automatically mark the data parallel comms of the parameters within the
@@ -231,7 +239,7 @@
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]
         fwd = _fmoe_general_global_forward(
-            moe_inp, gate_top_k_idx, self.expert_fn,
+            moe_inp, gate_top_k_idx, self.expert_fn_single if fmoe_faster_schedule else self.expert_fn,
             self.num_expert, self.world_size,
             experts=self.experts
         )
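`expert_fn` still loops over every local expert (now passing each expert its token count), while the new `expert_fn_single` runs exactly one expert, matching the `expert_idx` the smart scheduler's callbacks supply; the dispatch above picks it only when the faster schedule is enabled. A hedged usage sketch, assuming an fmoe install that includes this commit:

```python
import torch
import torch.nn as nn
from fmoe.layers import FMoE

class ToyExpert(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)
    def forward(self, x, fwd_expert_count=None):   # experts now receive their count too
        return self.fc(x)

moe = FMoE(num_expert=2, d_model=4, world_size=1, expert=ToyExpert)  # unfused ModuleList
inp, counts = torch.randn(8, 4), torch.tensor([3, 5])

out_all = moe.expert_fn(inp, counts)                     # loops over both experts
out_one = moe.expert_fn_single(inp[:3], counts[:1], 0)   # one expert, smart-scheduling path
```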
......
@@ -130,9 +130,20 @@ class MegatronMLP(FMoETransformerMLP):
         additional numpy rng is used.
         """
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
-        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
+        if type(self.experts) is nn.ModuleList:
+            for expert in self.experts:
+                _megatron_init_method(expert.htoh4, rng, self.sigma)
+        else:
+            _megatron_init_method(self.experts.htoh4, rng, self.sigma)
         std = self.sigma / math.sqrt(2.0 * self.num_layers)
-        _megatron_init_method(self.experts.h4toh, rng, std)
+        if type(self.experts) is nn.ModuleList:
+            for expert in self.experts:
+                _megatron_init_method(expert.h4toh, rng, std)
+        else:
+            _megatron_init_method(self.experts.h4toh, rng, std)
 
     def forward(self, inp):
         from megatron import mpu
......
@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 from .layers import FMoE
 from .linear import FMoELinear
+from .fastermoe.config import switch_from_env
 
 class _Expert(nn.Module):
@@ -47,10 +48,11 @@ class FMoETransformerMLP(FMoE):
         expert_rank=0,
         **kwargs
     ):
-        super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
-        self.experts = _Expert(
-            num_expert, d_model, d_hidden, activation, rank=expert_rank
-        )
+        def one_expert(d_model):
+            return _Expert(1, d_model, d_hidden, activation, rank=0)
+        expert = one_expert
+        super().__init__(num_expert=num_expert, d_model=d_model, expert=expert, **kwargs)
         self.mark_parallel_comm(expert_dp_comm)
 
     def forward(self, inp: torch.Tensor):
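Instead of building one fused `_Expert(num_expert, ...)`, the constructor now hands the base class a factory that makes a single-expert module, so `FMoE` can build one module per expert; that per-expert granularity is what the `expert_utils` indexing relies on. A hypothetical mirror of the expected construction (not the `FMoE` source):

```python
import torch.nn as nn

def build_experts(expert_factory, d_model, num_expert):
    # Mirror of what FMoE is expected to do when handed a callable `expert`:
    # one module per expert, so experts[idx] is individually addressable
    # for stashing, shadowing, and per-expert initialization.
    return nn.ModuleList([expert_factory(d_model) for _ in range(num_expert)])
```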
......