Unverified Commit 7d41fe88 authored by Tiago Antunes's avatar Tiago Antunes Committed by GitHub
Browse files

Added default weight initializations to FMoELinear and NoisyGate (#52)

* Added default weight initializations to FMoELinear and NoisyGate

* Following torch's naming convention
parent ec2d458d
......@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import math
class NoisyGate(BaseGate):
......@@ -24,6 +25,16 @@ class NoisyGate(BaseGate):
self.noise_epsilon = 1e-2
self.reset_parameters()
def reset_parameters(self):
    """Reset the gating weights to their default initialization.

    Follows the scheme used by ``torch.nn.Linear``
    (Kaiming-uniform with ``a = sqrt(5)``), see
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88

    NOTE(review): ``kaiming_uniform_`` derives its fan from the tensor
    layout, assuming an ``(out, in)`` weight — confirm the shapes of
    ``w_gate`` / ``w_noise`` match that convention where they are defined.
    """
    for weight in (self.w_gate, self.w_noise):
        torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
def _gates_to_load(self, gates):
"""Compute the true load per expert, given the gates.
The load is the number of examples for which the corresponding gate is >0.
......
......@@ -3,6 +3,7 @@ Layers that FMoE provides to users
"""
import torch
import torch.nn as nn
import math
from .functions import prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
......@@ -33,10 +34,12 @@ class FMoELinear(nn.Module):
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
......@@ -54,6 +57,13 @@ class FMoELinear(nn.Module):
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left to zero, similar as megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def mark_module_parallel_comm(module, comm):
r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment