"docs/source/vscode:/vscode.git/clone" did not exist on "dcfb7a99a56b294b5f3036fd5f5d3613a1511311"
Commit 527a8cc9 authored by Rick Ho

fix numerical issue in tests and typo

parent 12a7921b
@@ -30,7 +30,7 @@ class _FakeMegatronMLP(nn.Module):
         x = self.fc2(x)
         return x, torch.zeros_like(x)
 
-def _magatron_init_method(self, rng, sigma):
+def _megatron_init_method(self, rng, sigma):
     r'''
     Init method based on N(0, sigma).
     Copied from Megatron-LM
@@ -99,9 +99,9 @@ class MegatronMLP(FMoETransformerMLP):
         additional numpy rng is used.
         '''
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
-        _magatron_init_method(self.experts.htoh4, rng, self.sigma)
+        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
         std = self.sigma / math.sqrt(2.0 * self.num_layers)
-        _magatron_init_method(self.experts.h4toh, rng, std)
+        _megatron_init_method(self.experts.h4toh, rng, std)
 
     def forward(self, inp):
         return super().forward(inp), torch.zeros(self.hidden_size,
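For readers following the rename above: the commit shows only the signature and docstring of _megatron_init_method, not its body. The sketch below is an assumption of what an N(0, sigma) initializer driven by a shared numpy Generator can look like; init_normal and the plain nn.Linear target are illustrative, not the code in fmoe.megatron.

import numpy as np
import torch
import torch.nn as nn

def init_normal(module, rng, sigma):
    # Fill the module's weight with draws from N(0, sigma) taken from the
    # caller-supplied numpy Generator, and zero the bias. Passing the same
    # Generator to several layers keeps the whole initialization reproducible.
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(module.weight.shape))
    with torch.no_grad():
        module.weight.copy_(torch.from_numpy(weight).to(module.weight.dtype))
        if module.bias is not None:
            module.bias.zero_()

rng = np.random.default_rng(1234)
layer = nn.Linear(16, 64)
init_normal(layer, rng, sigma=0.02)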
@@ -5,12 +5,14 @@ from typing import List, Type, Union
 import pytest
 import torch
 import torch.nn as nn
+import numpy as np
 from copy import deepcopy
 from fmoe.gates import NaiveGate
 from fmoe.layers import FMoE
 from fmoe.transformer import _Expert
 from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
+from fmoe.megatron import _megatron_init_method
 from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
@@ -66,6 +68,7 @@ class MyMoE(FMoE):
         super().__init__(
             num_expert=num_expert,
             d_model=d_model,
+            gate=NaiveGate,
             world_size=world_size,
             mp_group=mp_group,
@@ -73,6 +76,10 @@ class MyMoE(FMoE):
         )
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+        rng = np.random.default_rng(1234)
+        _megatron_init_method(self.experts.htoh4, rng, 1.)
+        _megatron_init_method(self.experts.h4toh, rng, 1.)
 
 
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
@@ -353,9 +360,9 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
 
 if __name__ == "__main__":
     test_fmoe_linear(
-        batch_size=4,
-        num_expert=4,
-        d_model=8,
+        batch_size=2,
+        num_expert=2,
+        d_model=2,
         top_k=2,
         d_hidden=16,
         rank=0,
@@ -364,15 +371,3 @@ if __name__ == "__main__":
         dp_group=None,
         world_group=None,
     )
-    test_fmoe(
-        batch_size=4,
-        num_expert=4,
-        d_model=8,
-        top_k=2,
-        expert=NaiveExpert,
-        rank=0,
-        world_size=1,
-        mp_group=None,
-        dp_group=None,
-        world_group=None,
-    )