Commit 94eca783 authored by Rick Ho

reset parameters in megatron

parent f6afdbee
 r'''
 Layers that FMoE provides to users
 '''
-import math
 import torch
 import torch.nn as nn
-import numpy as np
 from .functions import moe_prepare_forward
 from .functions import MOEScatter, MOEGather, MOELinear
@@ -31,29 +29,6 @@ class FMoELinear(nn.Module):
             self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
         else:
             self.register_parameter('bias', None)
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        r'''
-        Initialize the weight as linear layers
-        '''
-        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
-        # copied from torch.nn.init.kaiming_uniform_
-        fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
-        gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
-        std = gain / math.sqrt(fan)
-        bound = math.sqrt(3.0) * std
-        device = self.weight.device
-        dtype = self.weight.dtype
-        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
-
-        if self.bias is not None:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
-            bound = 1 / math.sqrt(fan_in)
-            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
-            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)
 
     def forward(self, inp, fwd_expert_count):
         r'''
...
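For reference, the initializer removed above reproduces PyTorch's default `nn.Linear` initialization: with `a = sqrt(5)`, the leaky-relu gain is `sqrt(2 / (1 + 5)) = sqrt(1/3)`, so the uniform bound `sqrt(3) * gain / sqrt(fan_in)` simplifies to `1 / sqrt(fan_in)`. A quick sanity check of that arithmetic (an editorial sketch, not part of this commit; the tensor sizes are made up):

```python
import math
import torch
import torch.nn as nn

# One expert's weight slab, shaped (out_feat, in_feat); sizes are arbitrary.
weight = torch.empty(4096, 1024)
fan = nn.init._calculate_correct_fan(weight, 'fan_in')     # 1024
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))  # sqrt(1/3)
bound = math.sqrt(3.0) * gain / math.sqrt(fan)
# Same 1/sqrt(fan_in) bound that nn.Linear.reset_parameters uses.
assert math.isclose(bound, 1.0 / math.sqrt(fan))
```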
@@ -6,6 +6,8 @@ See `examples/megatron` for usage instructions.
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
+import math
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
@@ -24,6 +26,26 @@ class _MegatronMLP(nn.Module):
         return x, torch.zeros_like(x)
+
+
+def _random_init_weight(self, rng):
+    r'''
+    Copied from torch.nn.init.kaiming_uniform_
+    '''
+    fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
+    gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
+    std = gain / math.sqrt(fan)
+    bound = math.sqrt(3.0) * std
+    device = self.weight.device
+    dtype = self.weight.dtype
+    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
+    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+    if self.bias is not None:
+        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
+        bound = 1 / math.sqrt(fan_in)
+        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
+        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)
 
 class MegatronMLP(FMoETransformerMLP):
     r'''
     Make the FMoETransformerMLP layer that distributes experts across
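`_random_init_weight` only touches `self.weight` and `self.bias`, so it can be exercised on any module that follows the same parameter layout. A hypothetical usage sketch (`ToyExpert` and its shapes are invented here; it assumes the expert weight is 3-D, `(num_expert, out_feat, in_feat)`, so that `self.weight[0]` is one expert's 2-D matrix, matching the bias shape shown in the diff above):

```python
import numpy as np
import torch
import torch.nn as nn

class ToyExpert(nn.Module):
    # Mimics the assumed FMoELinear layout: one weight/bias slab per expert.
    def __init__(self, num_expert=4, in_feat=1024, out_feat=4096):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_expert, out_feat, in_feat))
        self.bias = nn.Parameter(torch.empty(num_expert, out_feat))

expert = ToyExpert()
_random_init_weight(expert, np.random.default_rng(0))
# Weights now lie in [-1/sqrt(in_feat), 1/sqrt(in_feat)], and the bias uses
# the same fan-in bound, per the derivation above.
assert expert.weight.abs().max() <= 1 / np.sqrt(1024)
```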
@@ -43,6 +65,18 @@ class MegatronMLP(FMoETransformerMLP):
             world_size=world_size, mp_group=group,
             expert_dp_comm='none' if args.distributed_experts else 'dp')
         self.hidden_size = args.hidden_size
+        self.rank = args.rank
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        r'''
+        Initialize the weights the same way as linear layers.
+        As Megatron uses a fixed random seed for some nasty stuff, an
+        additional numpy rng is used.
+        '''
+        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
+        _random_init_weight(self.experts.htoh4, rng)
+        _random_init_weight(self.experts.h4toh, rng)
 
     def forward(self, inp):
         return super().forward(inp), torch.zeros(self.hidden_size,
...
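The per-rank seed offset is the point of the change: as the docstring notes, Megatron fixes torch's random seed, and with identical seeds across model-parallel ranks a torch-based initializer would give every rank the same expert weights. Seeding a separate numpy generator with an offset of `rank` breaks that symmetry while staying reproducible. A small demonstration (an editorial sketch; `base_seed` stands in for the `np.random.randint(2048)` draw):

```python
import numpy as np

base_seed = 42  # stands in for np.random.randint(2048)

# Each rank seeds its own generator, so expert weights differ across ranks
# even though the torch seed is identical everywhere.
rank0 = np.random.default_rng(base_seed + 0).uniform(-1, 1, size=4)
rank1 = np.random.default_rng(base_seed + 1).uniform(-1, 1, size=4)
assert not np.allclose(rank0, rank1)
```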