Commit 72e9bc9e authored by Rick Ho

scatter gather kernel support for non-equal shapes and fix tests

parent 49c97411
@@ -228,10 +228,13 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
         torch::Tensor input,
         torch::Tensor pos) {
     auto smgr = getCudaStreamManager(input.device().index());
-    const auto batch_size = input.size(0);
+    const auto batch_size = pos.size(0);
     const auto in_feat = input.size(1);
-    auto input_buf = torch::empty_like(input);
+    auto opt = torch::TensorOptions()
+        .dtype(input.dtype())
+        .device(input.device());
+    auto input_buf = torch::empty({batch_size, in_feat}, opt);
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda",
         ([&] {
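The scatter buffer now takes its row count from pos rather than from the input, so the two no longer have to share a shape. A minimal PyTorch sketch of that shape contract (assuming, as the buffer shape suggests, that local_scatter copies row pos[i] of the input into row i of the buffer; ref_local_scatter is a hypothetical stand-in for the CUDA path):

    import torch

    def ref_local_scatter(inp, pos):
        # Buffer has pos.size(0) rows; pos may repeat or skip source rows,
        # so the buffer no longer has to match inp's shape.
        return inp.index_select(0, pos)

    inp = torch.randn(4, 6)                  # batch_size=4, in_feat=6
    pos = torch.tensor([0, 0, 2, 3, 1, 3])   # 6 entries -> 6-row buffer
    buf = ref_local_scatter(inp, pos)
    assert buf.shape == (pos.size(0), inp.size(1))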
@@ -250,10 +253,13 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
         torch::Tensor output_buf,
         torch::Tensor pos) {
     auto smgr = getCudaStreamManager(output_buf.device().index());
-    const auto batch_size = output_buf.size(0);
+    const auto batch_size = pos.size(0);
     const auto out_feat = output_buf.size(1);
-    auto output = torch::empty_like(output_buf);
+    auto opt = torch::TensorOptions()
+        .dtype(output_buf.dtype())
+        .device(output_buf.device());
+    auto output = torch::empty({batch_size, out_feat}, opt);
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda",
         ([&] {
...
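Same pattern on the gather side: because the result's leading dimension now comes from pos, torch::empty_like (which copies the source's full shape) is replaced by an explicit torch::empty carrying only the source's dtype and device. Roughly the following in Python terms (a sketch with made-up sizes, not the extension's actual code):

    import torch

    output_buf = torch.randn(4, 5)           # e.g. out_feat=5
    pos = torch.tensor([0, 0, 1, 2, 3, 3])
    # Result keeps output_buf's dtype/device but takes its row count from pos,
    # so it no longer has to equal output_buf.size(0).
    output = torch.empty(pos.size(0), output_buf.size(1),
                         dtype=output_buf.dtype, device=output_buf.device)
    assert output.shape == (6, 5)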
@@ -88,9 +88,13 @@ def test_module(moe, linear, inp, gate):
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
 
+rank = None
+world_size = None
+
 def test():
-    torch.manual_seed(42 + torch.distributed.get_rank())
-    torch.cuda.manual_seed(42 + torch.distributed.get_rank())
+    torch.manual_seed(42 + rank)
+    torch.cuda.manual_seed(42 + rank)
     batch_size = 4
     num_expert = 2
     in_feat = 6
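The seeds now use the module-level rank instead of calling torch.distributed.get_rank(), which would fail when no process group has been initialized (the single-process case added in the launcher below). Offsetting the seed by the rank keeps each worker reproducible while drawing different data; a tiny sketch with a hypothetical helper:

    import torch

    def seeded_batch(rank, shape=(4, 6)):
        # Same idea as the seeds above: offset by rank so each worker gets
        # its own, but reproducible, random stream.
        torch.manual_seed(42 + rank)
        return torch.randn(*shape)

    assert not torch.equal(seeded_batch(0), seeded_batch(1))
    assert torch.equal(seeded_batch(1), seeded_batch(1))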
@@ -123,13 +127,10 @@ def test():
     names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
     if world_size > 1:
-        rank = torch.distributed.get_rank()
         ou, wg, lwg, lbg = raw_out
         torch.distributed.all_reduce(wg)
         wg = wg[rank * num_expert:(rank + 1)* num_expert]
         raw_out = ou, wg, lwg, lbg
-    else:
-        rank = 0
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('Rank {} {} abs err {}'.format(rank, name, err))
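Here rank is now read from the module-level global set at startup rather than re-derived (or defaulted) inside test(). The comparison itself is unchanged: the all-reduced reference weight gradient covers the experts of every rank, and only the rows for this rank's num_expert experts are kept. In hypothetical shapes, the slicing looks like:

    import torch

    world_size, num_expert, rank = 2, 2, 1            # made-up sizes for illustration
    wg = torch.randn(world_size * num_expert, 7, 6)   # reference grad over all experts
    local = wg[rank * num_expert:(rank + 1) * num_expert]
    assert local.shape == (num_expert, 7, 6)          # only this rank's experts remain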
@@ -166,11 +167,15 @@ def test_dp():
 
 if __name__ == '__main__':
-    os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', 0)
-    os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', 1)
-    torch.distributed.init_process_group(backend='nccl')
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
+    os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
+    os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
+    if int(os.environ['WORLD_SIZE']) > 1:
+        torch.distributed.init_process_group(backend='nccl')
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
     if len(sys.argv) >= 2:
         task = sys.argv[1]
         print('Specificed task {}'.format(task))
...
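Two fixes in the launcher: the environment-variable defaults become strings (os.environ only stores str, so the old int fallbacks of 0 and 1 would raise a TypeError whenever the OMPI variables are absent), and the NCCL process group is only created when more than one process is requested, so the tests can presumably be run either as a plain single-process python invocation or under mpirun, which exports the OMPI_COMM_WORLD_* variables. A minimal check of the first point:

    import os

    try:
        os.environ['RANK'] = 0       # what the old int default did when no OMPI var was set
    except TypeError:
        os.environ['RANK'] = '0'     # str defaults avoid this, hence '0' / '1'
    assert os.environ['RANK'] == '0'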