Commit da4a8e5e authored by Jiezhong Qiu

fix import error and dtype/device problem in MoE

parent 5f5ccd47
@@ -3,6 +3,8 @@ Layers that FMoE provides to users
 '''
 import torch
 import torch.nn as nn
+import numpy as np
+import math
 from .functions import moe_prepare_forward
 from .functions import MOEScatter, MOEGather, MOELinear
@@ -31,13 +33,18 @@ class FMoELinear(nn.Module):
         Initialize the weight as linear layers
         '''
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
         # copied from https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_
         fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
         gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
         std = gain / math.sqrt(fan)
         bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
+        device = self.weight.device
+        dtype = self.weight.dtype
         for i in range(self.num_expert):
             weight = rng.uniform(-bound, bound, size=tuple(self.weight[i].size()))
-            self.weight.data[i] = torch.from_numpy(weight)
+            self.weight.data[i] = torch.tensor(weight, dtype=dtype, device=device)

     def forward(self, inp, fwd_expert_count):
         r'''
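
To make the second hunk concrete, the snippet below is a minimal, self-contained sketch of the dtype/device mismatch being fixed. It uses plain PyTorch and NumPy rather than the fastmoe classes; the half-precision parameter and the names param and w are illustrative stand-ins for FMoELinear.weight and the per-expert sample, not code from this repository.

import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Stand-in for FMoELinear.weight: possibly half precision, possibly on GPU.
param = torch.nn.Parameter(torch.empty(4, 8, dtype=torch.float16, device=device))

rng = np.random.default_rng(0)
w = rng.uniform(-0.1, 0.1, size=tuple(param[0].size()))  # NumPy array: float64, host memory

# Old path: torch.from_numpy keeps NumPy's float64 dtype and always lives on
# the CPU, so it does not match a half-precision and/or CUDA parameter.
old = torch.from_numpy(w)
print(old.dtype, old.device)   # torch.float64 cpu

# New path (what the diff switches to): build the tensor with the parameter's
# own dtype and device before writing it into the expert's weight slice.
new = torch.tensor(w, dtype=param.dtype, device=param.device)
param.data[0] = new
print(new.dtype, new.device)   # matches param.dtype / param.device
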