Commit 670d70ed authored by Rick Ho's avatar Rick Ho
Browse files

pass test on single gpu

parent 52e57316
#include "../local_exchange.cuh"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>
// Minimal error-checking wrapper: every CUDA runtime call returns a
// cudaError_t that must not be ignored.
#define CUDA_CHECK(call)                                                    \
    do {                                                                    \
        cudaError_t err_ = (call);                                          \
        if (err_ != cudaSuccess) {                                          \
            fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,   \
                    cudaGetErrorString(err_));                              \
            exit(1);                                                        \
        }                                                                   \
    } while (0)

/*
 * Standalone smoke test for fmoe_cuda_assign_pos_impl.
 *
 * Usage: prog n_worker n_expert batch_size topk
 *
 * Builds a random gate-index array (roughly 1 in 10 entries set to -1,
 * which the counting loop treats as "no expert assigned"), computes the
 * inclusive prefix sum of per-expert counts on the host, runs the
 * position-assignment kernel on device copies of the data, and prints
 * the resulting positions for manual inspection.
 */
int main(int argc, char* args[]) {
    if (argc < 5) {
        fprintf(stderr, "usage: %s n_worker n_expert batch_size topk\n",
                args[0]);
        return 1;
    }
    int n_worker = atoi(args[1]);
    int n_expert = atoi(args[2]);
    int batch_size = atoi(args[3]);
    int topk = atoi(args[4]);
    int tot_expert = n_worker * n_expert;
    long* gate_idx = new long[batch_size * topk];
    int* lec = new int[tot_expert];
    memset(lec, 0, sizeof(int) * tot_expert);
    for (int i = 0; i < batch_size * topk; ++i) {
        if (rand() % 10) {
            gate_idx[i] = rand() % tot_expert;
            ++lec[gate_idx[i]];
        } else {
            gate_idx[i] = -1;  // dropped token: not counted in lec
        }
    }
    // Inclusive prefix sum: lec[tot_expert - 1] becomes the total number
    // of non-dropped assignments.
    for (int i = 1; i < tot_expert; ++i) {
        lec[i] += lec[i - 1];
    }
    puts("gate idx");
    for (int i = 0; i < batch_size * topk; ++i) {
        printf("%ld ", gate_idx[i]);  // %ld: gate_idx elements are long
    }
    putchar('\n');
    int nlec = lec[tot_expert - 1];
    int* g_lec;
    CUDA_CHECK(cudaMalloc(&g_lec, sizeof(int) * tot_expert));
    CUDA_CHECK(cudaMemcpy(g_lec, lec, sizeof(int) * tot_expert,
                cudaMemcpyHostToDevice));
    long* g_gate_idx;
    CUDA_CHECK(cudaMalloc(&g_gate_idx, sizeof(long) * batch_size * topk));
    CUDA_CHECK(cudaMemcpy(g_gate_idx, gate_idx,
                sizeof(long) * batch_size * topk, cudaMemcpyHostToDevice));
    long* g_pos;
    CUDA_CHECK(cudaMalloc(&g_pos, sizeof(long) * nlec));
    auto smgr = getCudaStreamManager(0);
    fmoe_cuda_assign_pos_impl(g_lec, g_gate_idx, g_pos, batch_size * topk,
            topk, smgr);
    long* pos = new long[nlec];
    // Blocking D2H copy also synchronizes with the kernel before we read pos.
    CUDA_CHECK(cudaMemcpy(pos, g_pos, sizeof(long) * nlec,
                cudaMemcpyDeviceToHost));
    puts("pos");
    for (int i = 0; i < nlec; ++i) {
        printf("%ld ", pos[i]);  // %ld: pos elements are long
    }
    putchar('\n');
    // Release host and device buffers.
    delete[] pos;
    delete[] gate_idx;
    delete[] lec;
    CUDA_CHECK(cudaFree(g_lec));
    CUDA_CHECK(cudaFree(g_gate_idx));
    CUDA_CHECK(cudaFree(g_pos));
    return 0;
}
...@@ -79,10 +79,13 @@ def _local_scatter(inp, pos): ...@@ -79,10 +79,13 @@ def _local_scatter(inp, pos):
return inp_buf return inp_buf
def _local_gather(inp, pos, out_batch_size): def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
inp_buf = torch.zeros(out_batch_size, inp.shape[-1], inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
dtype=inp.dtype, device=inp.device) dtype=inp.dtype, device=inp.device)
inp_buf.index_copy_(0, pos, inp) if maybe_overlap:
inp_buf.index_add_(0, pos, inp)
else:
inp_buf.index_copy_(0, pos, inp)
return inp_buf return inp_buf
...@@ -187,7 +190,8 @@ class MOEGather(Function): ...@@ -187,7 +190,8 @@ class MOEGather(Function):
) )
else: else:
local_output_buf = global_output_buf local_output_buf = global_output_buf
output = _local_gather(local_output_buf, pos, local_batch_size) output = _local_gather(local_output_buf, pos, local_batch_size,
maybe_overlap=False)
ctx.moe_args = (global_output_buf.shape[0], world_size) ctx.moe_args = (global_output_buf.shape[0], world_size)
variables = (pos, local_expert_count, global_expert_count) variables = (pos, local_expert_count, global_expert_count)
......
...@@ -113,8 +113,11 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size): ...@@ -113,8 +113,11 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
fwd_expert_count, fwd_expert_count,
fwd_batch_size, fwd_batch_size,
) = prepare_forward(gate, num_expert, world_size) ) = prepare_forward(gate, num_expert, world_size)
topk = 1
if len(gate.shape) == 2:
topk = gate.shape[1]
x = MOEScatter.apply( x = MOEScatter.apply(
inp, pos % inp.shape[0], inp, pos // topk,
local_expert_count, global_expert_count, fwd_batch_size, world_size local_expert_count, global_expert_count, fwd_batch_size, world_size
) )
x = expert_fn(x, fwd_expert_count) x = expert_fn(x, fwd_expert_count)
......
...@@ -42,6 +42,7 @@ class BruteForceMoELinear(nn.Module): ...@@ -42,6 +42,7 @@ class BruteForceMoELinear(nn.Module):
x = x + self.bias_h4toh[i] x = x + self.bias_h4toh[i]
o[idx] = x o[idx] = x
gate_score = gate_score.unsqueeze(1) gate_score = gate_score.unsqueeze(1)
x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape( x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
-1, self.d_model -1, self.d_model
) )
......
import sys
from collections import OrderedDict
from typing import List, Type, Union
import pytest
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from fmoe.functions import MOEGather, MOEScatter, count_by_gate
from test_numerical import _assert_numercial
@pytest.mark.parametrize("n_expert", [1, 4, 8])
@pytest.mark.parametrize("topk", [1, 2])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("d_model", [6])
@pytest.mark.parametrize("world_size", [1])
def test_scatter(n_expert, topk, batch_size, d_model, world_size):
    """Check MOEScatter forward output against a naive per-row copy.

    gate_idx entries range over [-1, n_expert); -1 marks a dropped token
    (torch.randint over n_expert + 1 shifted down by one).
    """
    gate_idx = torch.randint(n_expert + 1, (batch_size, topk)) - 1
    gate_idx = gate_idx.long().cuda()
    pos, lec, gec = count_by_gate(gate_idx, n_expert, world_size)
    fbs = int(gec.sum().item())
    inp = torch.rand(batch_size, d_model).cuda()
    inp.requires_grad = True
    # NOTE(review): `pos % batch_size` maps a position back to an input
    # row; for topk > 1 the mapping `pos // topk` may be the intended one
    # (it is what the library's forward path uses) -- confirm.
    out = MOEScatter.apply(inp, pos % batch_size, lec, gec, fbs, world_size)
    out.sum().backward()
    # Naive reference: copy input rows according to pos with the same
    # position-to-row mapping as above.
    out_raw = torch.empty(pos.shape[0], d_model,
            device=inp.device, dtype=inp.dtype)
    for i, f in enumerate(pos.cpu()):
        out_raw[i] = inp[f % batch_size]
    _assert_numercial(['out'], [out], [out_raw], 0)
    # TODO: check grad
# Allow running this test directly without pytest.
if __name__ == '__main__':
    test_scatter(n_expert=4, topk=2, batch_size=8, d_model=6, world_size=1)
...@@ -21,11 +21,9 @@ def _perform_forward( ...@@ -21,11 +21,9 @@ def _perform_forward(
): ):
moe.zero_grad() moe.zero_grad()
moe_raw.zero_grad() moe_raw.zero_grad()
if not mp_group: inp = torch.rand(batch_size, d_model).cuda()
inp = torch.rand(batch_size, d_model).cuda() if mp_group is not None:
else:
group_sender = rank // mp_group.size() * mp_group.size() group_sender = rank // mp_group.size() * mp_group.size()
inp = torch.rand(batch_size, d_model).cuda()
torch.distributed.broadcast(inp, group_sender, group=mp_group) torch.distributed.broadcast(inp, group_sender, group=mp_group)
torch.distributed.broadcast( torch.distributed.broadcast(
moe.gate.gate.weight.data, group_sender, group=mp_group moe.gate.gate.weight.data, group_sender, group=mp_group
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment