Commit 794dd0e6 authored by Rick Ho

expert shadow backward with test

parent b5b72d41
......@@ -19,6 +19,51 @@ void setSmartSchEnabled(int s) {
smart_sch_enabled = s;
}
inline ncclDataType_t getNcclDataType(at::ScalarType t) {
switch (t) {
case at::kChar: return ncclInt8;
case at::kByte: return ncclUint8;
case at::kFloat: return ncclFloat;
case at::kDouble: return ncclDouble;
case at::kInt: return ncclInt32;
case at::kLong: return ncclInt64;
case at::kHalf: return ncclHalf;
case at::kBool: return ncclUint8;
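// NCCL has no boolean type, so bool tensors are transferred as uint8.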
#if defined(ENABLE_NCCL_BF16_DATATYPE)
case at::kBFloat16: return ncclBfloat16;
#endif
default: return ncclChar;
}
}
void _reduce_grad(
torch::Tensor t,
long root,
long expert_size) {
auto smgr = getCudaStreamManager(t.device().index());
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
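// Record an event on the current PyTorch stream and make FastMoE's
// communication stream wait on it, so the reduction only starts once the
// gradient buffer has been fully produced by preceding kernels.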
cudaEvent_t evt_stash;
cudaEventCreate(&evt_stash);
cudaEventRecord(evt_stash, torch_stream);
cudaStreamWaitEvent(smgr->stream(0), evt_stash, 0);
cudaEventDestroy(evt_stash);
auto dtype = getNcclDataType(t.scalar_type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(),
"fmoe_cuda_reduce_grad", ([&] {
void* buf = (void*)t.data_ptr<scalar_t>();
NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size,
dtype,
ncclSum, root,
smgr->ncclcomm, smgr->stream(0)));
})
);
}
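For orientation, a hedged sketch of how this entry point is meant to be driven from the Python side of this commit (the binding is registered as reduce_grad further down; the extension module name and the helper function below are assumptions for illustration, not part of the diff):

import torch
import fmoe_cuda  # assumption: the compiled TORCH_EXTENSION_NAME module

def reduce_shadow_grad(flat_grad: torch.Tensor, root: int) -> None:
    # Every rank contributes its locally computed gradient for a shadowed
    # expert; only `root` (the expert's owner) receives the summed result,
    # which it later writes back into the expert's parameters.
    fmoe_cuda.reduce_grad(flat_grad, root, flat_grad.numel())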
std::vector<torch::Tensor> _smart_sch_forward(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
......@@ -51,7 +96,6 @@ std::vector<torch::Tensor> _smart_sch_forward(
// TODO: maybe empty is faster
auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
std::vector<torch::Tensor> params;
......@@ -66,7 +110,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"fmoe_cuda_smart_sch_forward", ([&] {
fmoe_cuda_fused_forward_impl(
forward_fn,
......@@ -96,7 +140,6 @@ torch::Tensor _smart_sch_backward(
torch::Tensor stored_models,
long buf_batch_size,
long global_batch_size,
long expert_size,
long n_workers,
py::function backward_fn,
py::function stash_fn,
......@@ -112,10 +155,14 @@ torch::Tensor _smart_sch_backward(
auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
"fmoe_cuda_smartsch_backward", ([&] {
fmoe_cuda_fused_backward_impl(
backward_fn,
stash_fn,
pop_fn,
collect_fn,
set_grad_fn,
grad_out.device(),
grad_out.data_ptr<scalar_t>(),
......@@ -129,7 +176,7 @@ torch::Tensor _smart_sch_backward(
d_model, num_expert, rank, n_workers,
pipeline_gran, smgr);
}));
return {grad_in,};
return grad_in;
}
#endif
......@@ -13,7 +13,7 @@
template<typename scalar_t>
void _exchange_with(
void exchangeWith(
const scalar_t* sendbuf, size_t sendcount, int t_send,
scalar_t* recvbuf, size_t recvcount, int t_recv,
long d_model,
......@@ -40,15 +40,15 @@ void _exchange_with(
int idx_self = ei + rank * num_expert;
void _compute_ptrs(long num_expert, long rank, long world_size,
void computePtrs(long num_expert, long rank, long world_size,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
int *local_ptr,
int *global_ptr,
int *local_global_ptr) {
local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;
for (int i = 0; i < num_expert * world_size; ++i) {
local_ptr[i + 1] = local_ptr[i] + local_expert_count[i];
......@@ -76,7 +76,7 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
template<typename scalar_t>
void _compute_fn(py::function fn, c10::Device device,
void computeFn(py::function fn, c10::Device device,
scalar_t* inp_buf, scalar_t* out_buf,
long idx, long offset, long micro_batch_size, long d_model,
CudaStreamManager* smgr) {
......@@ -107,8 +107,8 @@ void fmoe_cuda_fused_forward_impl(
scalar_t* global_output_buf,
scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
long d_model,
......@@ -119,7 +119,7 @@ void fmoe_cuda_fused_forward_impl(
int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
_compute_ptrs(num_expert, rank, world_size,
computePtrs(num_expert, rank, world_size,
local_expert_count, global_expert_count, stored_models,
local_ptr, global_ptr, local_global_ptr);
......@@ -145,7 +145,7 @@ void fmoe_cuda_fused_forward_impl(
int rank_send = j + to_base;
int rank_recv = j + from_base;
GEN_IDX;
_exchange_with(input_buf + local_ptr[idx_send] * d_model,
exchangeWith(input_buf + local_ptr[idx_send] * d_model,
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_input_buf + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
......@@ -167,8 +167,9 @@ void fmoe_cuda_fused_forward_impl(
cudaEventCreate(&evt_get);
cudaEventRecord(evt_get, torch_stream);
cudaStreamWaitEvent(smgr->stream(1), evt_get);
cudaEventDestroy(evt_get);
}
NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
expert_size * sizeof(scalar_t), ncclChar,
i / num_expert, smgr->ncclcomm, smgr->stream(0)));
cudaEventCreate(evt_shadow + si);
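// evt_shadow[si] tracks this parameter broadcast; the local compute loop
// below waits on it before running the shadowed expert's forward pass.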
......@@ -183,10 +184,9 @@ void fmoe_cuda_fused_forward_impl(
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
long micro_batch_size = global_ptr[ei * world_size +
(from_base + pipeline_gran)] - offset;
_compute_fn(forward_fn, device,
computeFn(forward_fn, device,
global_input_buf, global_output_buf,
step, offset, micro_batch_size, d_model, smgr);
}
......@@ -200,7 +200,7 @@ void fmoe_cuda_fused_forward_impl(
cudaStreamWaitEvent(torch_stream, evt_shadow[si], 0);
long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i];
_compute_fn(forward_fn, device,
computeFn(forward_fn, device,
input_buf, output_buf,
n_groups + si, offset, micro_batch_size, d_model, smgr);
++si;
......@@ -218,7 +218,7 @@ void fmoe_cuda_fused_forward_impl(
int rank_send = j + from_base;
int rank_recv = j + to_base;
GEN_IDX;
_exchange_with(global_output_buf + global_ptr[gidx_send] * d_model,
exchangeWith(global_output_buf + global_ptr[gidx_send] * d_model,
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
output_buf + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
......@@ -241,15 +241,16 @@ void fmoe_cuda_fused_forward_impl(
}
delete [] input_ready;
delete [] output_ready;
if (params.size()) {
delete [] evt_shadow;
}
}
template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
py::function collect_fn,
py::function set_grad_fn,
c10::Device device,
scalar_t* grad_out,
......@@ -257,8 +258,8 @@ void fmoe_cuda_fused_backward_impl(
scalar_t* global_grad_in,
scalar_t* grad_in,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
long d_model,
long num_expert, long rank, long world_size,
......@@ -269,10 +270,9 @@ void fmoe_cuda_fused_backward_impl(
int *global_ptr = new int[num_expert * world_size + 1];
int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
_compute_ptrs(num_expert, rank, world_size,
computePtrs(num_expert, rank, world_size,
local_expert_count, global_expert_count, stored_models,
local_ptr, global_ptr, local_global_ptr);
if (pipeline_gran > world_size) {
pipeline_gran = world_size;
}
......@@ -286,6 +286,7 @@ void fmoe_cuda_fused_backward_impl(
cudaEventCreate(output_ready + i);
}
// S_0 ... S_n
for (long step = 0; step < n_groups; ++step) {
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
......@@ -294,7 +295,7 @@ void fmoe_cuda_fused_backward_impl(
int rank_send = j + to_base;
int rank_recv = j + from_base;
GEN_IDX;
_exchange_with(grad_out + local_ptr[idx_send] * d_model,
exchangeWith(grad_out + local_ptr[idx_send] * d_model,
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_grad_out + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
......@@ -305,21 +306,54 @@ void fmoe_cuda_fused_backward_impl(
cudaEventRecord(input_ready[step], smgr->stream(0));
}
// Shadowed experts backward and reduce
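// For each expert shadowed on this rank:
//   stash_fn(si)          swaps the si-th shadow's parameters into the local experts,
//   computeFn(...)        runs backward_fn over the rows this rank kept locally,
//   collect_fn(si, owner) flattens the resulting parameter grads and reduces
//                         them onto the owning rank over the communication stream,
//   evt_reduce[.]         marks, on the owner itself, when its own expert's
//                         reduced gradients are ready (consumed further below).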
cudaEvent_t *evt_reduce = new cudaEvent_t[num_expert];
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
stash_fn(si);
long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i];
computeFn(backward_fn, device,
grad_out, grad_in,
n_groups + si, offset, micro_batch_size, d_model, smgr);
collect_fn(si, i / num_expert);
if (i / num_expert == rank) {
cudaEventCreate(evt_reduce + i % num_expert);
cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
}
++si;
}
}
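// Restore this rank's own expert parameters now that all shadowed backward
// passes have been queued.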
pop_fn();
// C_0 ... C_n
for (long step = 0; step < n_groups; ++step) {
cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
long micro_batch_size = global_ptr[ei * world_size +
(from_base + pipeline_gran)] - offset;
_compute_fn(backward_fn, device,
computeFn(backward_fn, device,
global_grad_out, global_grad_in,
step, offset, micro_batch_size, d_model, smgr);
}
cudaEventRecord(output_ready[step], torch_stream);
}
// Collect gradients for shadowed experts
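// Only the owner rank applies the reduced result: it makes the torch stream
// wait on evt_reduce for its expert, then writes the summed gradients back
// into the parameters via set_grad_fn.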
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
if (i / num_expert == rank) {
cudaStreamWaitEvent(torch_stream, evt_reduce[i % num_expert], 0);
set_grad_fn(si);
}
++si;
}
}
// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
for (int ei = 0; ei < num_expert; ++ei) {
......@@ -329,7 +363,7 @@ void fmoe_cuda_fused_backward_impl(
int rank_send = j + from_base;
int rank_recv = j + to_base;
GEN_IDX;
_exchange_with(global_grad_in + global_ptr[gidx_send] * d_model,
exchangeWith(global_grad_in + global_ptr[gidx_send] * d_model,
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
grad_in + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
......@@ -341,36 +375,6 @@ void fmoe_cuda_fused_backward_impl(
checkCudaErrors(cudaGetLastError());
/* TODO: Shadowing support
int offset = global_ptr[world_size * num_expert];
for (int j = 0; j < world_size; j++) {
for (int i = 0; i < num_expert; i++) {
int idx = j * num_expert + i;
if (!stored_models[idx])
continue;
weight1 = params[j][0][0].data_ptr<scalar_t>();
weight2 = params[j][0][last].data_ptr<scalar_t>();
grad_weight1 = params[j][0][0].mutable_grad().data_ptr<scalar_t>();
grad_weight2 = params[j][0][last].mutable_grad().data_ptr<scalar_t>();
auto stream = 2 + (idx % (SMGR_N_STREAMS- 2));
_compute_mlp_backward(
original_input_buf + local_ptr[idx] * d_model, weight1, weight2,
middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf, grad_out + local_ptr[idx] * d_model,
grad_middle + (offset + local_global_ptr[idx]) * d_hidden, grad_weight1, grad_weight2, grad_in + local_ptr[idx] * d_model,
i,
0, local_expert_count[idx],
d_model, d_hidden, 0, // we never consider it to be the first since it's already initialized to zero and we are lazy
smgr->stream(stream), smgr->handle(stream));
}
}
*/
delete [] local_ptr;
delete [] global_ptr;
delete [] local_global_ptr;
......@@ -381,6 +385,12 @@ void fmoe_cuda_fused_backward_impl(
}
delete [] input_ready;
delete [] output_ready;
for (long i = 0; i < num_expert; ++i) {
if (stored_models[i + rank * num_expert]) {
cudaEventDestroy(evt_reduce[i]);
}
}
delete [] evt_reduce;
}
#endif // SMART_SCHEDULE_H
......@@ -77,13 +77,16 @@ torch::Tensor _smart_sch_backward(
torch::Tensor stored_models,
long buf_batch_size,
long global_batch_size,
long expert_size,
long n_workers,
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
py::function collect_fn,
py::function set_grad_fn);
void _reduce_grad(
torch::Tensor t,
long root,
long expert_size);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
......@@ -95,6 +98,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling");
m.def("smart_sch_backward", &_smart_sch_backward, "E2E MoE layer backward with smart scheduling");
m.def("reduce_grad", &_reduce_grad, "Reduce gradients over FastMoE's communication stream");
#endif
m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
......
......@@ -6,7 +6,6 @@ def get_expert_param_size(e):
def get_expert_params(e, out):
print('gep to {}'.format(out))
offset = 0
for n, p in e.named_parameters():
seg = out[offset:offset + p.numel()]
......@@ -42,7 +41,7 @@ def collect_expert_grads(e, grads):
seg = grads[offset:offset + p.numel()]
offset += p.numel()
if p.grad is not None:
seg.copy_(p.grad)
seg.copy_(p.grad.flatten())
p.grad = None
else:
seg.zero_()
......@@ -56,4 +55,4 @@ def set_grads(e, grads):
if p.grad is None:
p.grad = seg.clone()
else:
p.grad += seg
p.grad += seg.reshape(p.shape)
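All of these helpers assume one flat 1-D buffer per expert, with parameters packed back-to-back in named_parameters() order, which is why the copies above need flatten() on the way in and reshape(p.shape) on the way out. A minimal, self-contained sketch of that convention (toy module and hypothetical helper names, not part of this commit):

import torch

def pack_params(module: torch.nn.Module) -> torch.Tensor:
    # Pack every parameter into one flat buffer, in named_parameters() order
    # (cf. get_expert_param_size / get_expert_params above).
    flat = torch.empty(sum(p.numel() for p in module.parameters()))
    offset = 0
    for _, p in module.named_parameters():
        flat[offset:offset + p.numel()].copy_(p.detach().flatten())
        offset += p.numel()
    return flat

def unpack_grads(module: torch.nn.Module, flat: torch.Tensor) -> None:
    # Inverse direction: slice the flat buffer and reshape each slice back to
    # the parameter's shape, as set_grads above does for p.grad.
    offset = 0
    for _, p in module.named_parameters():
        seg = flat[offset:offset + p.numel()]
        offset += p.numel()
        p.grad = seg.reshape(p.shape).clone()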
......@@ -24,16 +24,15 @@ class MoEForward(Function):
world_size):
local_input_buf = _local_scatter(inp, pos_s)
# TODO: leave this for future work on expert shadowing
# model_params = [[tuple(m.parameters()) for m in node] for node in models]
ctx.gibs = [None] * (world_size * 2)
ctx.gobs = [None] * (world_size * 2)
def _expert_forward(x, y, idx):
nothing = lambda a: a
x = x.data
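# x.data detaches the scattered input from the surrounding graph; grad is
# re-enabled below so each micro-batch builds its own small autograd subgraph,
# stashed in ctx.gibs/ctx.gobs and replayed later in backward(). The no-op
# saved_tensors_hooks presumably sidestep autograd's version-counter checks
# when the shared buffers are overwritten in place by later micro-batches.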
with torch.enable_grad():
x.requires_grad = True
y0 = expert_fn(x, [x.shape[0]])
with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
y0 = expert_fn(x, [x.shape[0]])
ctx.gibs[idx] = x
ctx.gobs[idx] = y0
y.copy_(y0)
......@@ -60,7 +59,7 @@ class MoEForward(Function):
maybe_overlap=False)
variables = (pos_s, pos_g, local_expert_count, global_expert_count,
stored_models, gib)
stored_models, gib, local_input_buf)
ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
ctx.save_for_backward(*variables)
......@@ -70,30 +69,33 @@ class MoEForward(Function):
@staticmethod
def backward(ctx, grad_out):
(pos_s, pos_g, local_expert_count, global_expert_count,
stored_models, _) = ctx.saved_tensors
stored_models, _1, _2) = ctx.saved_tensors
(fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
def _expert_backward(grad_y, grad_x, idx):
y = ctx.gobs[idx]
x = ctx.gibs[idx]
torch.autograd.backward([y], [grad_y])
grad_x.copy_(x.grad)
experts = ctx.experts
def stash_fn(idx):
expert_utils.stash_expert_params(experts, ctx.shadows[idx])
pop_fn = lambda: expert_utils.pop_expert_params(experts)
collect_fn = lambda g: expert_utils.collect_expert_grads(experts, g)
set_grad_fn = lambda g: expert_utils.set_grads(experts, g)
def collect_fn(idx, root):
grad = ctx.shadows[idx]
expert_utils.collect_expert_grads(experts, grad)
fmoe_native.reduce_grad(grad, root, ctx.expert_size)
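# Every rank contributes to the reduction; only `root` (the expert's owner)
# receives the sum and later applies it through set_grad_fn.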
set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx])
grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
grad_in_buf = fmoe_native.smart_sch_backward(
grad_out_buf,
local_expert_count, global_expert_count,
stored_models,
pos_s.shape[0], fwd_batch_size, ctx.expert_size,
world_size, _expert_backward,
stash_fn, pop_fn, collect_fn, set_grad_fn)
pos_s.shape[0], fwd_batch_size,
world_size,
_expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
return (None, None, grad_in, None, None, None, None, None, None, None, None)
......
......@@ -62,16 +62,18 @@ def _test_faster_shadow(d_model, batch_size, n_expert):
dist.broadcast(stored_models, 0)
stored_models = stored_models.cpu()
# if rank == 0:
# print('stored models {}'.format(stored_models))
ensure_comm(x1, None)
y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1, stored_models=stored_models)
# y1.sum().backward()
y1.sum().backward()
y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
# y2.sum().backward()
_assert_numerical(['out'], [y1], [y2], rank)
# _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
# [y1, x1.grad, m1.bias.grad, m1.weight.grad],
# [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)
y2.sum().backward()
_assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
[y1, x1.grad, m1.bias.grad, m1.weight.grad],
[y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)
if __name__ == '__main__':
......