"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "c877cda61fee6b0fa77f5d8faaa985ad00fc2cab"
Unverified commit 593feab2, authored by Rick Ho, committed by GitHub

Merge pull request #10 from laekov/megatron-mlp-init

Use Megatron's init method for the FFN
parents 5e5b4044 8ddd246f
@@ -30,6 +30,20 @@ class _FakeMegatronMLP(nn.Module):
         x = self.fc2(x)
         return x, torch.zeros_like(x)
 
 
+def _magatron_init_method(self, rng, sigma):
+    r'''
+    Init method based on N(0, sigma).
+    Copied from Megatron-LM
+    '''
+    device = self.weight.device
+    dtype = self.weight.dtype
+    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
+    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+
+    if self.bias is not None:
+        # Always initialize bias to zero.
+        with torch.no_grad():
+            self.bias.zero_()
+
+
 def _random_init_weight(self, rng):
     r'''
@@ -71,6 +85,8 @@ class MegatronMLP(FMoETransformerMLP):
             expert_dp_comm='none' if args.distributed_experts else 'dp')
         self.hidden_size = args.hidden_size
         self.rank = args.rank
+        self.sigma = args.init_method_std
+        self.num_layers = args.num_layers
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -80,8 +96,9 @@ class MegatronMLP(FMoETransformerMLP):
         additional numpy rng is used.
         '''
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
-        _random_init_weight(self.experts.htoh4, rng)
-        _random_init_weight(self.experts.h4toh, rng)
+        _magatron_init_method(self.experts.htoh4, rng, self.sigma)
+        std = self.sigma / math.sqrt(2.0 * self.num_layers)
+        _magatron_init_method(self.experts.h4toh, rng, std)
 
     def forward(self, inp):
         return super().forward(inp), torch.zeros(self.hidden_size,
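For reference, below is a minimal standalone sketch of the initialization scheme this patch introduces: the first expert projection (htoh4) is drawn from N(0, sigma) with sigma = args.init_method_std, while the output projection (h4toh) uses the Megatron-LM residual scaling sigma / sqrt(2 * num_layers). The layer sizes, sigma, num_layers, and the helper name init_normal are illustrative placeholders, not taken from the commit.

import math
import numpy as np
import torch
from torch import nn

sigma = 0.02        # placeholder for args.init_method_std
num_layers = 24     # placeholder for args.num_layers

def init_normal(layer, rng, std):
    # Draw weights from N(0, std) on the layer's dtype/device; zero the bias.
    weight = rng.normal(loc=0.0, scale=std, size=tuple(layer.weight.size()))
    layer.weight.data = torch.tensor(weight,
                                     dtype=layer.weight.dtype,
                                     device=layer.weight.device)
    if layer.bias is not None:
        with torch.no_grad():
            layer.bias.zero_()

rng = np.random.default_rng(0)
htoh4 = nn.Linear(1024, 4096)   # stand-in for self.experts.htoh4
h4toh = nn.Linear(4096, 1024)   # stand-in for self.experts.h4toh

init_normal(htoh4, rng, sigma)
# Output projection is scaled down by sqrt(2 * num_layers), as in Megatron-LM,
# so residual-branch contributions stay bounded as depth grows.
init_normal(h4toh, rng, sigma / math.sqrt(2.0 * num_layers))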