Commit 527a8cc9 authored by Rick Ho

fix numerical issue in tests and typo

parent 12a7921b
@@ -30,7 +30,7 @@ class _FakeMegatronMLP(nn.Module):
         x = self.fc2(x)
         return x, torch.zeros_like(x)
 
-def _magatron_init_method(self, rng, sigma):
+def _megatron_init_method(self, rng, sigma):
     r'''
     Init method based on N(0, sigma).
     Copied from Megatron-LM
@@ -99,9 +99,9 @@ class MegatronMLP(FMoETransformerMLP):
         additional numpy rng is used.
         '''
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
-        _magatron_init_method(self.experts.htoh4, rng, self.sigma)
+        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
         std = self.sigma / math.sqrt(2.0 * self.num_layers)
-        _magatron_init_method(self.experts.h4toh, rng, std)
+        _megatron_init_method(self.experts.h4toh, rng, std)
 
     def forward(self, inp):
         return super().forward(inp), torch.zeros(self.hidden_size,
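The renamed helper is only called in these hunks, not defined, so its body is not part of the diff. Based on its docstring ("Init method based on N(0, sigma). Copied from Megatron-LM") and the call sites above, a minimal sketch of such an init might look like the following; the module layout and function name with the `_sketch` suffix are assumptions for illustration, not code taken from this commit.

import numpy as np
import torch

def _megatron_init_method_sketch(module, rng, sigma):
    # Hypothetical stand-in for _megatron_init_method: draw every weight from
    # N(0, sigma) using a shared numpy Generator, so any two callers that start
    # from the same Generator state produce identical parameters, and zero the
    # bias. The real fastmoe helper may handle per-expert layouts differently.
    with torch.no_grad():
        sample = rng.normal(loc=0.0, scale=sigma, size=tuple(module.weight.shape))
        module.weight.copy_(torch.from_numpy(sample))
        if getattr(module, "bias", None) is not None:
            module.bias.zero_()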
@@ -5,12 +5,14 @@ from typing import List, Type, Union
 import pytest
 import torch
 import torch.nn as nn
+import numpy as np
 
 from copy import deepcopy
 from fmoe.gates import NaiveGate
 from fmoe.layers import FMoE
 from fmoe.transformer import _Expert
 from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
+from fmoe.megatron import _megatron_init_method
 
 from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
@@ -66,6 +68,7 @@ class MyMoE(FMoE):
         super().__init__(
             num_expert=num_expert,
             d_model=d_model,
             gate=NaiveGate,
             world_size=world_size,
             mp_group=mp_group,
@@ -73,6 +76,10 @@ class MyMoE(FMoE):
         )
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+        rng = np.random.default_rng(1234)
+        _megatron_init_method(self.experts.htoh4, rng, 1.)
+        _megatron_init_method(self.experts.h4toh, rng, 1.)
 
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
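The new MyMoE initialization seeds np.random.default_rng(1234) and runs _megatron_init_method over both expert layers, so every construction of MyMoE starts from the same expert weights regardless of process or PyTorch RNG state. The snippet below is only a generic illustration of the property the test relies on, not code from this commit.

import numpy as np

# Two Generators built from the same seed replay the same stream, so two
# modules initialized from them get bit-identical weights and can be
# compared numerically against a reference implementation.
rng_a = np.random.default_rng(1234)
rng_b = np.random.default_rng(1234)
assert np.array_equal(rng_a.normal(size=8), rng_b.normal(size=8))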
@@ -353,9 +360,9 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
 if __name__ == "__main__":
     test_fmoe_linear(
-        batch_size=4,
-        num_expert=4,
-        d_model=8,
+        batch_size=2,
+        num_expert=2,
+        d_model=2,
         top_k=2,
         d_hidden=16,
         rank=0,
@@ -364,15 +371,3 @@ if __name__ == "__main__":
         dp_group=None,
         world_group=None,
     )
-    test_fmoe(
-        batch_size=4,
-        num_expert=4,
-        d_model=8,
-        top_k=2,
-        expert=NaiveExpert,
-        rank=0,
-        world_size=1,
-        mp_group=None,
-        dp_group=None,
-        world_group=None,
-    )
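The __main__ block now runs test_fmoe_linear on a much smaller problem (batch_size, num_expert and d_model drop from 4, 4, 8 to 2, 2, 2) and no longer invokes test_fmoe directly. A plausible reading, consistent with the commit message, is that smaller tensors keep the gap between the fused MoE path and the brute-force reference within tolerance. The snippet below is only a generic illustration of how reduction order perturbs float32 results, not code from the repository.

import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)
# Summing the same values in a different order gives a slightly different
# float32 result; larger layers accumulate more of this kind of drift.
forward_sum = x.sum()
reversed_sum = x.flip(0).sum()
print((forward_sum - reversed_sum).abs().item())  # typically small but nonzero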