Commit a807e2a3 authored by Rick Ho

backward buggy on grad weight

parent 46c3722d
@@ -17,7 +17,6 @@ torch::Tensor _smart_sch_forward(
         long global_batch_size,
         long n_workers,
         py::function forward_fn) {
     if (pipeline_gran == -1) {
         char* p = getenv("FMOE_FASTER_GROUP_SIZE");
         if (p) {
@@ -40,7 +39,7 @@ torch::Tensor _smart_sch_forward(
     auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
-            "fmoe_cuda_fused_forward", ([&] {
+            "fmoe_cuda_smart_sch_forward", ([&] {
         fmoe_cuda_fused_forward_impl(
             forward_fn,
             input_buf.device(),
@@ -59,75 +58,42 @@ torch::Tensor _smart_sch_forward(
     return output_buf;
 }
-/*
-std::vector<torch::Tensor> _fused_backward(
-        torch::Tensor input_buf,
-        std::vector<std::vector<std::vector<torch::Tensor>>> params,
-        torch::Tensor middle_buf,
-        torch::Tensor output_buf,
+torch::Tensor _smart_sch_backward(
         torch::Tensor grad_out,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
-        torch::Tensor inp,
         torch::Tensor stored_models,
-        long global_batch_size,
         long buf_batch_size,
-        long n_workers, bool has_bias) {
+        long global_batch_size,
+        long n_workers,
+        py::function backward_fn) {
     const auto num_expert = local_expert_count.size(0) / n_workers;
-    auto smgr = getCudaStreamManager(input_buf.device().index());
+    auto smgr = getCudaStreamManager(grad_out.device().index());
     int rank;
     ncclCommUserRank(smgr->ncclcomm, &rank);
-    const auto d_hidden = params[rank][0][0].size(1);
-    const auto d_model = params[rank][0][0].size(2);
-    auto global_grad_out = input_buf.new_zeros({global_batch_size, d_model});
-    auto grad_middle = input_buf.new_zeros({global_batch_size, d_hidden});
-    auto global_grad_in = input_buf.new_zeros({global_batch_size, d_model});
-    auto grad_in = input_buf.new_zeros({buf_batch_size, d_model});
-    for (auto node : params)
-        for (auto expert : node)
-            for (int i = 0; i < expert.size(); i++) {
-                // create the respective gradient of each tensor
-                CHECK_INPUT(expert[i]);
-                if (expert[i].grad().defined()) {
-                    CHECK_INPUT(expert[i].grad());
-                    continue;
-                }
-                expert[i].mutable_grad() = input_buf.new_zeros(expert[i].sizes());
-            }
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
-            "fmoe_cuda_fused_backward", ([&] {
+    const auto d_model = grad_out.size(1);
+    auto global_grad_out = grad_out.new_zeros({global_batch_size, d_model});
+    auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
+    auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
+            "fmoe_cuda_smartsch_backward", ([&] {
         fmoe_cuda_fused_backward_impl(
-            input_buf.data_ptr<scalar_t>(),
-            inp.data_ptr<scalar_t>(),
-            params,
-            middle_buf.data_ptr<scalar_t>(),
-            output_buf.data_ptr<scalar_t>(),
+            backward_fn,
+            grad_out.device(),
             grad_out.data_ptr<scalar_t>(),
             global_grad_out.data_ptr<scalar_t>(),
             global_grad_in.data_ptr<scalar_t>(),
-            grad_middle.data_ptr<scalar_t>(),
             grad_in.data_ptr<scalar_t>(),
             local_expert_count.data_ptr<long>(),
             global_expert_count.data_ptr<long>(),
             stored_models.data_ptr<bool>(),
-            d_model, d_hidden, num_expert, rank, n_workers, has_bias,
+            d_model, num_expert, rank, n_workers,
             pipeline_gran, smgr);
     }));
     return {grad_in,};
 }
-*/
 #endif
@@ -74,7 +74,7 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
 }
 template<typename scalar_t>
-void _compute_forward(py::function fn, c10::Device device,
+void _compute_fn(py::function fn, c10::Device device,
         scalar_t* inp_buf, scalar_t* out_buf,
         int ei, long step, long offset, long micro_batch_size, long d_model) {
     auto options = torch::TensorOptions()
@@ -89,14 +89,6 @@ void _compute_forward(py::function fn, c10::Device device,
 }
-template<typename scalar_t>
-void _compute_backward(py::function fn,
-        scalar_t* inp_buf, scalar_t* out_buf,
-        long* local_expert_count, long* global_expert_count,
-        int ei, long offset, long micro_batch_size, long d_model) {
-}
 template<typename scalar_t>
 void fmoe_cuda_fused_forward_impl(
         py::function forward_fn,
@@ -162,7 +154,7 @@ void fmoe_cuda_fused_forward_impl(
             long micro_batch_size = global_ptr[ei * world_size +
                 (from_base + pipeline_gran)] - offset;
-            _compute_forward(forward_fn, device,
+            _compute_fn(forward_fn, device,
                     global_input_buf, global_output_buf,
                     ei, step, offset, micro_batch_size, d_model);
         }
@@ -230,19 +222,17 @@ void fmoe_cuda_fused_forward_impl(
 template<typename scalar_t>
 void fmoe_cuda_fused_backward_impl(
         py::function backward_fn,
-        const scalar_t* input_buf,
-        const scalar_t* output_buf,
-        const scalar_t* grad_out,
+        c10::Device device,
+        scalar_t* grad_out,
         scalar_t* global_grad_out,
         scalar_t* global_grad_in,
         scalar_t* grad_in,
         const long* local_expert_count,
         const long* global_expert_count,
         const bool* stored_models,
-        long d_model, long d_hidden,
+        long d_model,
         long num_expert, long rank, long world_size,
         long pipeline_gran, CudaStreamManager* smgr) {
@@ -294,11 +284,9 @@ void fmoe_cuda_fused_backward_impl(
             long micro_batch_size = global_ptr[ei * world_size +
                 (from_base + pipeline_gran)] - offset;
-            _compute_backward(backward_fn,
-                    input_buf, output_buf, global_grad_out,
-                    global_grad_in,
-                    ei,
-                    offset, micro_batch_size);
+            _compute_fn(backward_fn, device,
                     global_grad_out, global_grad_in,
+                    ei, step, offset, micro_batch_size, d_model);
         }
         // TODO: get pytorch's compute stream
 }
...
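For orientation, the unified `_compute_fn` above now drives both directions: it presumably wraps slices of the staged buffers as tensors, passes them to the Python callback together with an index, and expects the callback to write its result into the second tensor in place. A hedged sketch of that callback contract follows; `forward_callback`/`backward_callback` and the doubling "expert" are illustrative stand-ins, the real callbacks being the `_expert_forward(x, y, idx)` and `_expert_backward(grad_y, grad_x, idx)` closures in `MoEForward`.

```python
import torch

def forward_callback(x: torch.Tensor, y: torch.Tensor, idx: int) -> None:
    # x: micro-batch slice of the gathered input buffer for expert/step `idx`
    # y: same-shaped slice of the output buffer, filled in place
    y.copy_(2.0 * x)               # toy expert: y = 2x

def backward_callback(grad_y: torch.Tensor, grad_x: torch.Tensor, idx: int) -> None:
    # grad_y: incoming gradient slice; grad_x: slice to fill with dL/dx, in place
    grad_x.copy_(2.0 * grad_y)     # matches the toy expert above

# Shapes mirror the (micro_batch_size, d_model) views built on the C++ side.
x = torch.randn(8, 16)
y = torch.empty_like(x)
forward_callback(x, y, 0)
grad_in = torch.empty_like(x)
backward_callback(torch.ones_like(y), grad_in, 0)
```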
@@ -65,6 +65,15 @@ torch::Tensor _smart_sch_forward(
         torch::Tensor stored_models,
         long global_batch_size, long n_workers,
         py::function forward_fn);
+torch::Tensor _smart_sch_backward(
+        torch::Tensor grad_out,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        torch::Tensor stored_models,
+        long buf_batch_size,
+        long global_batch_size,
+        long n_workers,
+        py::function backward_fn);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef FMOE_USE_NCCL
@@ -75,6 +84,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)");
     m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling");
+    m.def("smart_sch_backward", &_smart_sch_backward, "E2E MoE layer backward with smart scheduling");
 #endif
     m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
...
@@ -29,8 +29,9 @@ class MoEForward(Function):
         ctx.gobs = [None] * world_size
         def _expert_forward(x, y, idx):
             x = x.data
-            x.requires_grad = True
-            y0 = expert_fn(x, [x.shape[0]])
+            with torch.enable_grad():
+                x.requires_grad = True
+                y0 = expert_fn(x, [x.shape[0]])
             ctx.gibs[idx] = x
             ctx.gobs[idx] = y0
             y.copy_(y0)
@@ -55,21 +56,21 @@ class MoEForward(Function):
     @staticmethod
     def backward(ctx, grad_out):
         (pos_s, pos_g, local_expert_count, global_expert_count,
-                gib, gmb, gob, stored_models) = ctx.saved_tensors
+                stored_models) = ctx.saved_tensors
         (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
-        def _expert_backward(grad, idx):
+        def _expert_backward(grad_y, grad_x, idx):
             y = ctx.gobs[idx]
-            torch.autograd.backward([y], [grad])
+            torch.autograd.backward([y], [grad_y])
             x = ctx.gibs[idx]
-            return x.grad
+            grad_x.copy_(x.grad)
         grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
         grad_in_buf = fmoe_native.smart_sch_backward(
-                gib, gmb, gob, grad_out_buf,
+                grad_out_buf,
                 local_expert_count, global_expert_count,
                 stored_models,
-                fwd_batch_size, pos_s.shape[0],
+                pos_s.shape[0], fwd_batch_size,
                 world_size, _expert_backward)
         grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
...
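The substance of the commit is the grad-weight fix in `_expert_forward` / `_expert_backward` above: `Function.forward` runs with autograd disabled, so the expert has to be re-entered under `torch.enable_grad()` for its parameters to be recorded in a sub-graph, and `backward` then replays autograd on that stored sub-graph so both the input gradient and the expert's weight/bias gradients are populated. A minimal self-contained sketch of the same pattern, with illustrative names rather than FastMoE's actual classes:

```python
import torch

class ExpertWrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, expert):
        x = x.data                    # detach from the outer graph
        with torch.enable_grad():     # forward() runs with grad disabled by default
            x.requires_grad = True
            y = expert(x)             # expert weight/bias now enter this sub-graph
        ctx.expert_io = (x, y)        # keep the sub-graph alive for backward()
        return y.detach()

    @staticmethod
    def backward(ctx, grad_y):
        x, y = ctx.expert_io
        # Replaying autograd on the stored sub-graph fills x.grad as well as the
        # expert's parameter .grad fields (this is what makes grad weight correct).
        torch.autograd.backward([y], [grad_y])
        return x.grad, None           # no gradient for the module argument

expert = torch.nn.Linear(4, 4)
inp = torch.randn(2, 4, requires_grad=True)
out = ExpertWrap.apply(inp, expert)
out.sum().backward()
print(expert.weight.grad.shape)  # torch.Size([4, 4])
print(inp.grad.shape)            # torch.Size([2, 4])
```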
@@ -19,7 +19,8 @@ from fmoe.layers import _fmoe_general_global_forward as naive_fwd
 @pytest.mark.parametrize("d_model", [1024])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("n_expert", [1])
-def test_faster_schedule(n_process, d_model, batch_size, n_expert):
+@pytest.mark.parametrize("group_sz", [1, 2, 4])
+def test_faster_schedule(n_process, d_model, batch_size, n_expert, group_sz):
     _run_distributed('_test_faster_schedule',
             n_process,
             {
@@ -28,7 +29,9 @@ def test_faster_schedule(n_process, d_model, batch_size, n_expert):
                 'n_expert': n_expert
             },
             script=__file__,
-            env=dict()
+            env=dict(
+                FMOE_FASTER_GROUP_SIZE=str(group_sz)
+            )
     )
@@ -37,19 +40,33 @@ def _test_faster_schedule(d_model, batch_size, n_expert):
     rank = dist.get_rank()
     world_size = dist.get_world_size()
-    x = torch.rand(batch_size, d_model).cuda()
-    x.requires_grad = True
+    x1 = torch.rand(batch_size, d_model).cuda()
+    x1.requires_grad = True
+    x2 = x1.data.clone()
+    x2.requires_grad = True
     topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()
-    m = torch.nn.Linear(d_model, d_model).cuda()
+    m1 = torch.nn.Linear(d_model, d_model).cuda()
+    m2 = torch.nn.Linear(d_model, d_model).cuda()
+    with torch.no_grad():
+        m2.weight.copy_(m1.weight)
+        m2.bias.copy_(m1.bias)
-    def expert_fn(x, fec):
-        y = m(x)
+    def ef1(x, fec):
+        y = m1(x)
         return y
+    def ef2(x, fec):
+        y = m2(x)
+        return y
+    ensure_comm(x1, None)
+    y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size)
+    y1.sum().backward()
-    ensure_comm(x, None)
-    y = smart_fwd(x, topk_idx, expert_fn, n_expert, world_size)
-    z = naive_fwd(x, topk_idx, expert_fn, n_expert, world_size)
-    _assert_numerical(['out'], [y], [z], rank)
+    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size)
+    y2.sum().backward()
+    _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
+            [y1, x1.grad, m1.bias.grad, m1.weight.grad],
+            [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)
 if __name__ == '__main__':
@@ -57,5 +74,5 @@ if __name__ == '__main__':
     args = json.loads(sys.argv[2])
     locals()[sys.argv[1]](**args)
 else:
-    # test_faster_schedule(8, 16, 16, 1)
+    # test_faster_schedule(8, 16, 16, 1, 2)
     _test_faster_schedule(4, 2, 1)