Commit 6acc3e41 authored by Rick Ho's avatar Rick Ho
Browse files

new test structure according to new expert_fn structure

parent c0a3a425
...@@ -10,6 +10,19 @@ from fmoe.transformer import _Expert ...@@ -10,6 +10,19 @@ from fmoe.transformer import _Expert
n_devices = int(os.environ.get("N_GPUS", "2")) n_devices = int(os.environ.get("N_GPUS", "2"))
class MyMoE(FMoE):
    """Single-process FMoE variant backed by the reference ``_Expert`` stack.

    Hard-wires ``world_size=1`` and ``mp_group=None`` so the module never
    spans processes; used by the data-parallel test which replicates the
    whole module per device instead.
    """

    def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
        base_kwargs = dict(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=1,
            mp_group=None,
            top_k=top_k,
        )
        super().__init__(**base_kwargs)
        # Swap in the reference two-layer feed-forward experts so tests can
        # compare their weights/grads against a brute-force implementation.
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@pytest.mark.parametrize("num_expert", [4, 8]) @pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3]) @pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
...@@ -26,22 +39,12 @@ def test_fmoe_dp( ...@@ -26,22 +39,12 @@ def test_fmoe_dp(
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)
experts = _Expert(num_expert, d_model, d_hidden, activation).cuda() moe = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda()
def expert_fn(inp, gate):
return experts(inp, gate)
moe = FMoE(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=1,
mp_group=None,
expert_fn=expert_fn,
top_k=top_k,
).cuda()
moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices))) moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices)))
for i in range(5): for i in range(5):
output = moe_dp(torch.rand(batch_size, d_model).cuda()) output = moe_dp(torch.rand(batch_size, d_model).cuda())
if __name__ == '__main__':
test_fmoe_dp(4, 2, 4, 16, 32)
...@@ -52,6 +52,20 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank): ...@@ -52,6 +52,20 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
assert False assert False
class MyMoE(FMoE):
    """Distributed FMoE variant backed by the reference ``_Expert`` stack.

    Unlike the data-parallel test's helper, this one forwards the caller's
    ``world_size`` and ``mp_group`` to the base class so experts shard
    across processes.
    """

    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
                 top_k, activation):
        base_kwargs = dict(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            mp_group=mp_group,
            top_k=top_k,
        )
        super().__init__(**base_kwargs)
        # Reference two-layer feed-forward experts; the test all-gathers
        # their weights to seed an equivalent brute-force MoE.
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@pytest.mark.parametrize("num_expert", [4, 8]) @pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3]) @pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("batch_size", [4])
...@@ -74,20 +88,8 @@ def test_fmoe_linear( ...@@ -74,20 +88,8 @@ def test_fmoe_linear(
torch.manual_seed(42 + rank) torch.manual_seed(42 + rank)
torch.cuda.manual_seed(42 + rank) torch.cuda.manual_seed(42 + rank)
experts = _Expert(num_expert, d_model, d_hidden, activation).cuda() moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
activation).cuda()
def expert_fn(inp, gate):
return experts(inp, gate)
moe = FMoE(
num_expert=num_expert,
d_model=d_model,
gate=NaiveGate,
world_size=world_size,
mp_group=mp_group,
expert_fn=expert_fn,
top_k=top_k,
).cuda()
moe_raw = BruteForceMoELinear( moe_raw = BruteForceMoELinear(
activation=activation, activation=activation,
...@@ -99,30 +101,30 @@ def test_fmoe_linear( ...@@ -99,30 +101,30 @@ def test_fmoe_linear(
).cuda() ).cuda()
if world_size == 1: if world_size == 1:
moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone() moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
moe_raw.bias_htoh4.data = experts.htoh4.bias.data.clone() moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone() moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
moe_raw.bias_h4toh.data = experts.h4toh.bias.data.clone() moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
else: else:
weight_htoh4_array = [ weight_htoh4_array = [
torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size) torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
] ]
bias_htoh4_array = [ bias_htoh4_array = [
torch.empty_like(experts.htoh4.bias.data) for _ in range(world_size) torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
] ]
torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data) torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
torch.distributed.all_gather(bias_htoh4_array, experts.htoh4.bias.data) torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0) moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0) moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)
weight_h4toh_array = [ weight_h4toh_array = [
torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size) torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
] ]
bias_h4toh_array = [ bias_h4toh_array = [
torch.empty_like(experts.h4toh.bias.data) for _ in range(world_size) torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
] ]
torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data) torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
torch.distributed.all_gather(bias_h4toh_array, experts.h4toh.bias.data) torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0) moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0) moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)
...@@ -130,7 +132,7 @@ def test_fmoe_linear( ...@@ -130,7 +132,7 @@ def test_fmoe_linear(
moe, moe_raw, batch_size, d_model, top_k, rank, mp_group moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
) )
moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad
raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
if world_size > 1: if world_size > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment