Unverified Commit 7d41fe88 authored by Tiago Antunes, committed by GitHub

Added default weight initializations to FMoELinear and NoisyGate (#52)

* Added default weight initializations to FMoELinear and NoisyGate

* Following torch's naming convention
parent ec2d458d
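In effect, constructing the touched modules now yields fully initialized parameters with no manual init step. A minimal usage sketch, assuming a fastmoe build that includes this patch; the keyword names below are taken from the diff, but the exact constructor signature is not shown here, so treat the call as an assumption:

    import torch
    from fmoe.layers import FMoELinear  # assumes fastmoe with this patch installed

    # Keyword names come from the diff; the constructor signature is an assumption.
    expert = FMoELinear(num_expert=4, in_feat=1024, out_feat=512, bias=True, rank=0)
    # reset_parameters() ran inside __init__: the weight is Kaiming-uniform and
    # the bias starts at zero rather than uninitialized memory.
    print(expert.weight.abs().max().item() > 0)   # True: weights are initialized
    print(expert.bias.abs().sum().item() == 0.0)  # True: bias starts at zero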
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributions.normal import Normal
+import math


 class NoisyGate(BaseGate):
@@ -24,6 +25,16 @@ class NoisyGate(BaseGate):
         self.noise_epsilon = 1e-2

+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # Approach is the same as in torch.nn.Linear
+        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
+        torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5))
+        torch.nn.init.kaiming_uniform_(self.w_noise, a=math.sqrt(5))
+
     def _gates_to_load(self, gates):
         """Compute the true load per expert, given the gates.
         The load is the number of examples for which the corresponding gate is >0.
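For reference, kaiming_uniform_ with a=math.sqrt(5) is exactly torch.nn.Linear's default: the leaky_relu gain sqrt(2 / (1 + a^2)) equals sqrt(1/3), so the uniform bound gain * sqrt(3 / fan_in) collapses to 1 / sqrt(fan_in). A quick sanity sketch; the shape below is illustrative, not taken from this commit:

    import math
    import torch

    # kaiming_uniform_ with a=sqrt(5) draws from U(-bound, bound), where
    # bound = gain * sqrt(3 / fan_in) and gain = sqrt(2 / (1 + a^2)) = sqrt(1/3),
    # i.e. bound = 1 / sqrt(fan_in).
    w = torch.empty(1024, 8)  # any 2D parameter; fan_in of a 2D tensor is size(1)
    torch.nn.init.kaiming_uniform_(w, a=math.sqrt(5))
    bound = 1 / math.sqrt(w.size(1))
    assert w.abs().max().item() <= bound  # all samples stay inside U(-bound, bound)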
@@ -3,6 +3,7 @@ Layers that FMoE provides to users
 """
 import torch
 import torch.nn as nn
+import math

 from .functions import prepare_forward
 from .functions import MOEScatter, MOEGather, MOELinear
@@ -33,10 +34,12 @@ class FMoELinear(nn.Module):
         self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
         if bias:
-            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
+            self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
         else:
             self.register_parameter("bias", None)

+        self.reset_parameters()
+
     def forward(self, inp, fwd_expert_count):
         r"""
         Call MOE function
@@ -54,6 +57,13 @@ class FMoELinear(nn.Module):
             self.rank,
         )

+    def reset_parameters(self):
+        # Approach is the same as in torch.nn.Linear
+        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
+        # bias is left as zero, similar to Megatron
+        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
 def mark_module_parallel_comm(module, comm):
     r"""
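The same bound applies to the 3D expert weight, with one wrinkle worth noting: for a tensor with more than two dimensions, PyTorch computes fan_in as size(1) times the product of the trailing dimensions, which is out_feat * in_feat here rather than in_feat alone. A standalone sketch mirroring the new defaults, with illustrative shapes:

    import math
    import torch

    # Mirror of the new FMoELinear defaults on standalone tensors. For a 3D
    # tensor, fan_in = size(1) * size(2) = out_feat * in_feat, so the uniform
    # bound becomes 1 / sqrt(out_feat * in_feat), not the per-expert
    # 1 / sqrt(in_feat).
    num_expert, out_feat, in_feat = 4, 512, 1024
    weight = torch.empty(num_expert, out_feat, in_feat)
    torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
    bias = torch.zeros(num_expert, out_feat)  # bias is left at zero, as in the diff

    assert weight.abs().max().item() <= 1 / math.sqrt(out_feat * in_feat)
    assert bias.abs().sum().item() == 0.0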