Commit b5b72d41 authored by Rick Ho's avatar Rick Ho
Browse files

forward tested

parent 771dc62d
...@@ -168,7 +168,7 @@ void fmoe_cuda_fused_forward_impl( ...@@ -168,7 +168,7 @@ void fmoe_cuda_fused_forward_impl(
cudaEventRecord(evt_get, torch_stream); cudaEventRecord(evt_get, torch_stream);
cudaStreamWaitEvent(smgr->stream(1), evt_get); cudaStreamWaitEvent(smgr->stream(1), evt_get);
} }
NCCL_SAFE_CALL(ncclBcast(params[si].data_ptr<void>(), NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
expert_size * sizeof(scalar_t), ncclChar, expert_size * sizeof(scalar_t), ncclChar,
i / num_expert, smgr->ncclcomm, smgr->stream(0))); i / num_expert, smgr->ncclcomm, smgr->stream(0)));
cudaEventCreate(evt_shadow + si); cudaEventCreate(evt_shadow + si);
......
...@@ -80,6 +80,8 @@ torch::Tensor _smart_sch_backward( ...@@ -80,6 +80,8 @@ torch::Tensor _smart_sch_backward(
long expert_size, long expert_size,
long n_workers, long n_workers,
py::function backward_fn, py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
py::function collect_fn, py::function collect_fn,
py::function set_grad_fn); py::function set_grad_fn);
......
...@@ -6,11 +6,12 @@ def get_expert_param_size(e): ...@@ -6,11 +6,12 @@ def get_expert_param_size(e):
def get_expert_params(e, out): def get_expert_params(e, out):
print('gep to {}'.format(out))
offset = 0 offset = 0
for n, p in e.named_parameters(): for n, p in e.named_parameters():
seg = out[offset:offset + p.numel()] seg = out[offset:offset + p.numel()]
offset += p.numel() offset += p.numel()
seg.copy_(p) seg.copy_(p.data.flatten())
def stash_expert_params(e, params): def stash_expert_params(e, params):
...@@ -27,6 +28,8 @@ def stash_expert_params(e, params): ...@@ -27,6 +28,8 @@ def stash_expert_params(e, params):
def pop_expert_params(e): def pop_expert_params(e):
if not hasattr(e, 'expert_param_stash'):
return
for n, p in e.named_parameters(): for n, p in e.named_parameters():
with torch.no_grad(): with torch.no_grad():
p.copy_(e.expert_param_stash[n]) p.copy_(e.expert_param_stash[n])
......
...@@ -7,7 +7,7 @@ from torch.autograd.function import Function ...@@ -7,7 +7,7 @@ from torch.autograd.function import Function
from fmoe.functions import prepare_forward, ensure_comm from fmoe.functions import prepare_forward, ensure_comm
from fmoe.functions import _local_scatter, _local_gather from fmoe.functions import _local_scatter, _local_gather
import fmoe_cuda as fmoe_native import fmoe_cuda as fmoe_native
import expert_utils from fmoe.fastermoe import expert_utils
class MoEForward(Function): class MoEForward(Function):
...@@ -47,7 +47,7 @@ class MoEForward(Function): ...@@ -47,7 +47,7 @@ class MoEForward(Function):
pop_fn = lambda: expert_utils.pop_expert_params(experts) pop_fn = lambda: expert_utils.pop_expert_params(experts)
ctx.shadows = [None] * world_size ctx.shadows = [None] * world_size
def stash_fn(params, idx): def stash_fn(params, idx):
expert_utils.stash_expert_params(experts, p) expert_utils.stash_expert_params(experts, params)
ctx.shadows[idx] = params ctx.shadows[idx] = params
local_output_buf, gib = fmoe_native.smart_sch_forward( local_output_buf, gib = fmoe_native.smart_sch_forward(
...@@ -99,7 +99,7 @@ class MoEForward(Function): ...@@ -99,7 +99,7 @@ class MoEForward(Function):
return (None, None, grad_in, None, None, None, None, None, None, None, None) return (None, None, grad_in, None, None, None, None, None, None, None, None)
def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None): def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
# TODO: Using multiple tensors as input is to be supported. # TODO: Using multiple tensors as input is to be supported.
assert(isinstance(inp, torch.Tensor)) assert(isinstance(inp, torch.Tensor))
# TODO: Support many experts on each process # TODO: Support many experts on each process
...@@ -113,7 +113,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp ...@@ -113,7 +113,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
) = prepare_forward(gate, n_expert, world_size) ) = prepare_forward(gate, n_expert, world_size)
# TODO: Expert shadowing is to be supported. Currently using all 0s # TODO: Expert shadowing is to be supported. Currently using all 0s
stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool) if stored_models is None:
stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)
topk = 1 topk = 1
if len(gate.shape) == 2: if len(gate.shape) == 2:
......
import pytest
import os
import sys
import json
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.functions import ensure_comm
from test_ddp import _ensure_initialized, _run_distributed
from test_numerical import _assert_numerical
from fmoe.fastermoe.schedule import _fmoe_general_global_forward as smart_fwd
from fmoe.layers import _fmoe_general_global_forward as naive_fwd
@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1])
@pytest.mark.parametrize("group_sz", [1, 2, 4])
def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz):
    """Spawn `n_process` workers running `_test_faster_shadow`.

    Each worker compares the scheduled (fastermoe) forward against the
    naive forward; the fastermoe group size is passed to the children
    through the FMOE_FASTER_GROUP_SIZE environment variable.
    """
    worker_kwargs = {
        'd_model': d_model,
        'batch_size': batch_size,
        'n_expert': n_expert,
    }
    child_env = {'FMOE_FASTER_GROUP_SIZE': str(group_sz)}
    _run_distributed('_test_faster_shadow',
            n_process,
            worker_kwargs,
            script=__file__,
            env=child_env)
def _test_faster_shadow(d_model, batch_size, n_expert):
    """Worker body: check the smart-scheduled forward against the naive one.

    Runs inside every distributed process. Builds two numerically identical
    Linear "experts", routes the same random batch through the fastermoe
    scheduled forward (with a random expert-shadowing mask) and the naive
    forward, then asserts the two outputs match.
    """
    _ensure_initialized()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Two copies of the same input: x1 feeds the smart path, x2 the naive path.
    x1 = torch.rand(batch_size, d_model).cuda()
    x1.requires_grad = True
    x2 = x1.data.clone()
    x2.requires_grad = True
    # Random top-2 expert assignment per token.
    # NOTE(review): no explicit seeding here — ranks agree on routing only if
    # their RNG states happen to agree; confirm _ensure_initialized seeds them.
    topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()
    m1 = torch.nn.Linear(d_model, d_model).cuda()
    m2 = torch.nn.Linear(d_model, d_model).cuda()
    with torch.no_grad():
        # Copy parameters so both experts compute the same function.
        m2.weight.copy_(m1.weight)
        m2.bias.copy_(m1.bias)
    def ef1(x, fec):
        y = m1(x)
        return y
    def ef2(x, fec):
        y = m2(x)
        return y
    # Randomly pick which workers' experts are shadowed; rank 0's choice is
    # broadcast so every rank uses the same mask (smart path only).
    stored_models = torch.randint(0, 2, (world_size,)).bool().cuda()
    dist.broadcast(stored_models, 0)
    stored_models = stored_models.cpu()
    ensure_comm(x1, None)
    y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1, stored_models=stored_models)
    # y1.sum().backward()
    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
    # y2.sum().backward()
    # Forward outputs only for now; the backward comparison below is
    # disabled (commit message says "forward tested").
    _assert_numerical(['out'], [y1], [y2], rank)
    # _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
    #         [y1, x1.grad, m1.bias.grad, m1.weight.grad],
    #         [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)
if __name__ == '__main__':
    # When launched by _run_distributed, argv[1] names the worker entry
    # point and argv[2] carries its keyword arguments as JSON.
    if len(sys.argv) >= 3:
        kwargs = json.loads(sys.argv[2])
        entry = globals()[sys.argv[1]]
        entry(**kwargs)
    else:
        # Manual single-process invocation for debugging.
        # test_faster_shadow(8, 16, 16, 1, 2)
        _test_faster_shadow(4, 2, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment