Commit 864a4522 authored by Rick Ho

multi-gpu forward pass test

parent 069cf01a
@@ -27,15 +27,17 @@ class MOELayer(nn.Module):
 class MOELayer_raw(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
+    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
+            world_size=0):
         super(MOELayer_raw, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.weight = nn.Parameter(
-            torch.Tensor(num_expert, out_feat, in_feat))
+            torch.Tensor(num_expert * world_size, out_feat, in_feat))
         self.reset_parameters()

     def reset_parameters(self):
         for i in range(self.num_expert):
             linear = nn.Linear(in_features=self.in_feat,
...
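With this change, the single-process reference layer allocates one weight tensor covering every worker's experts, so the experts of worker w occupy rows w * num_expert through (w + 1) * num_expert - 1 (matching the dim-0 concatenation of the all_gather further down). A minimal sketch of that indexing; the expert_weight helper is hypothetical, not part of the commit:

    import torch

    num_expert, world_size = 32, 2
    in_feat = out_feat = 1024

    # One tensor holding every worker's experts, as in MOELayer_raw above.
    weight = torch.Tensor(num_expert * world_size, out_feat, in_feat)

    def expert_weight(w, e):
        # Hypothetical helper: expert e owned by worker w.
        return weight[w * num_expert + e]

    assert expert_weight(1, 0).shape == (out_feat, in_feat)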
@@ -155,6 +155,7 @@ void moe_cuda_global_scatter_impl(
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
+    smgr->sync(1);
 }

 std::vector<torch::Tensor> moe_cuda_global_scatter(
@@ -224,6 +225,7 @@ void moe_cuda_global_gather_impl(
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
+    smgr->sync(1);
 }

 std::vector<torch::Tensor> moe_cuda_global_gather(
@@ -238,7 +240,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
     AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
             "moe_cuda_global_gather", ([&] {
-        moe_cuda_global_scatter_impl<scalar_t>(
+        moe_cuda_global_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
             local_expert_count.data_ptr<int>(),
             global_expert_count.data_ptr<int>(),
...
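Two things happen in this file: the new smgr->sync(1) calls make global scatter/gather block until the NCCL send/recv operations issued inside the ncclGroupStart()/ncclGroupEnd() section have completed before the functions return, and the gather dispatcher is fixed to call moe_cuda_global_gather_impl instead of the scatter impl. For intuition, the exchange global_scatter performs can be approximated in plain PyTorch with torch.distributed.all_to_all_single; this is a rough analogue under assumed count layouts (rows sorted by worker-major global expert id), not the extension's actual code:

    import torch
    import torch.distributed as dist

    def global_scatter_sketch(local_input_buf, local_expert_count, global_expert_count):
        # local_expert_count: rows this worker sends, grouped per destination expert;
        # global_expert_count: rows this worker receives, grouped per source.
        world_size = dist.get_world_size()
        in_splits = local_expert_count.view(world_size, -1).sum(dim=1).tolist()
        out_splits = global_expert_count.view(world_size, -1).sum(dim=1).tolist()
        output = local_input_buf.new_empty(sum(out_splits), local_input_buf.shape[1])
        # One fused all-to-all in place of the per-expert ncclSend/ncclRecv pairs.
        dist.all_to_all_single(output, local_input_buf, out_splits, in_splits)
        return output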
@@ -67,6 +67,7 @@ def test_module(moe, linear, inp, gate):
     moe.zero_grad()
     x = (linear(inp))
     output = moe(x, gate)
+    print('ooutput', torch.distributed.get_rank(), output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
@@ -86,8 +87,14 @@ def test():
         moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
     else:
         moe = MOELayer(num_expert, in_feat, out_feat).cuda()
-    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
-    moe_raw.weight.data = moe.weight.data.clone()
+    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
+    if world_size == 1:
+        moe_raw.weight.data = moe.weight.data.clone()
+    else:
+        weight_array = [torch.empty_like(moe.weight.data).cpu()
+                for _ in range(world_size)]
+        torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
+        moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()

     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0,
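The weights are gathered via CPU copies (the process group in __main__ uses the MPI backend, which typically communicates CPU tensors), and the dim-0 concatenation reproduces the worker-major expert layout that MOELayer_raw now expects. A small sanity check one could add, assuming the group is initialized (not part of the commit):

    rank = torch.distributed.get_rank()
    # This rank's shard should occupy rows [rank * num_expert, (rank + 1) * num_expert).
    mine = moe_raw.weight.data[rank * num_expert:(rank + 1) * num_expert]
    assert torch.equal(mine.cpu(), moe.weight.data.cpu())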
@@ -97,11 +104,12 @@ def test():
     # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

     moe_out = test_module(moe, linear, inp.clone(), gate.clone())
+    print('hhh')
+    return
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
-    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
+    if world_size == 1:
+        names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
+    else:
+        names = ['Out']
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('{} abs err {}'.format(name, err))
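In the multi-GPU case only the forward output is compared (names = ['Out']) since the backward path is not exercised yet; note that the print('hhh') / return above short-circuits the comparison entirely in this commit. Once that debugging exit is removed, a tolerance-based check is usually more robust than a summed absolute error; a possible variant (not in the commit):

    for name, mo, ro in zip(names, moe_out, raw_out):
        # allclose scales the tolerance with the reference magnitude.
        ok = torch.allclose(mo, ro, rtol=1e-5, atol=1e-6)
        print('{} allclose {}, max abs err {}'.format(
            name, ok, (mo - ro).abs().max().item()))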
@@ -134,8 +142,6 @@ def test_dp():

 if __name__ == '__main__':
     torch.distributed.init_process_group(backend='mpi')
     world_size = torch.distributed.get_world_size()
-    if world_size == 1:
-        world_size = None
     test()
     # print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
     # perf()
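Since the process group uses backend='mpi', the test is presumably launched through an MPI launcher, for example `mpirun -np 2 python <test script>` (the script name is not shown in this diff), which requires a PyTorch build compiled with MPI support.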