Commit da4a8e5e authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

fix import error and dtype/device problem in MoE

parent 5f5ccd47
...@@ -3,6 +3,8 @@ Layers that FMoE provides to users ...@@ -3,6 +3,8 @@ Layers that FMoE provides to users
''' '''
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
import math
from .functions import moe_prepare_forward from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear from .functions import MOEScatter, MOEGather, MOELinear
...@@ -31,13 +33,18 @@ class FMoELinear(nn.Module): ...@@ -31,13 +33,18 @@ class FMoELinear(nn.Module):
Initialize the weight as linear layers Initialize the weight as linear layers
''' '''
rng = np.random.default_rng(np.random.randint(2048) + self.rank) rng = np.random.default_rng(np.random.randint(2048) + self.rank)
# copied from https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_
fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in') fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5)) gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
std = gain / math.sqrt(fan) std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
device = self.weight.device
dtype = self.weight.dtype
for i in range(self.num_expert): for i in range(self.num_expert):
weight = rng.uniform(-bound, bound, size=tuple(self.weight[i].size())) weight = rng.uniform(-bound, bound, size=tuple(self.weight[i].size()))
self.weight.data[i] = torch.from_numpy(weight) self.weight.data[i] = torch.tensor(weight, dtype=dtype, device=device)
def forward(self, inp, fwd_expert_count): def forward(self, inp, fwd_expert_count):
r''' r'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment