Commit 3b82e379 authored by Rick Ho

fix lint

parent 43af1522
@@ -47,7 +47,8 @@ class DistributedGroupedDataParallel(nn.Module):
         else:
             self.comms["world"] = world_group

-        def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False):
+        def allreduce_params(no_scale=False,
+                reduce_after=False, fp32_allreduce=False):
             groups = dict()
             for p in self.module.parameters():
                 if not p.requires_grad or p.grad is None:
...
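Note on the reflowed `allreduce_params` signature: the helper buckets parameters by their data-parallel communication group and all-reduces each bucket's gradients. A minimal sketch of that pattern (the `dp_comm` attribute and the `comm_groups` mapping are illustrative assumptions, not FastMoE's exact internals):

```python
# Illustrative only: bucket gradients by process group, then all-reduce each
# bucket. `dp_comm` and `comm_groups` are assumed names for this sketch.
import torch.distributed as dist

def allreduce_grads_sketch(module, comm_groups):
    groups = dict()
    for p in module.parameters():
        if not p.requires_grad or p.grad is None:
            continue
        group_name = getattr(p, "dp_comm", "world")
        groups.setdefault(group_name, []).append(p)
    for name, params in groups.items():
        for p in params:
            dist.all_reduce(p.grad.data, group=comm_groups[name])
```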
@@ -40,7 +40,8 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         )
     else:
         global_expert_count = local_expert_count
-    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
+    fwd_expert_count = global_expert_count.view(world_size,
+            num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (
         pos,
...
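The wrapped line computes how many tokens each local expert will receive in total across all workers. A toy run with assumed sizes:

```python
import torch

world_size, num_expert = 2, 2
# Flattened per-worker counts: worker 0 sends [3, 1], worker 1 sends [2, 4].
global_expert_count = torch.tensor([3, 1, 2, 4])
fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
print(fwd_expert_count)                    # tensor([5, 5])
print(int(fwd_expert_count.sum().item()))  # fwd_batch_size == 10
```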
@@ -23,7 +23,8 @@ class ZeroGate(nn.Module):
         idx = torch.zeros(
             inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
         )
-        score = torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
+        score = torch.ones(inp.shape[0] * self.top_k,
+                           device=inp.device) / self.top_k
         return idx, score.reshape(-1, 1, self.top_k)
...
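For reference, the reflowed `score` line gives every routed slot a uniform 1/top_k weight while `idx` sends everything to expert 0. For example, with a batch of 4 and top_k=2:

```python
import torch

batch, top_k = 4, 2
idx = torch.zeros(batch * top_k, dtype=torch.int64)  # all tokens -> expert 0
score = torch.ones(batch * top_k) / top_k            # uniform 0.5 weights
print(idx.shape, score.reshape(-1, 1, top_k).shape)
# torch.Size([8]) torch.Size([4, 1, 2])
```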
@@ -114,7 +114,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         fwd_batch_size,
     ) = moe_prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
-        inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
+        inp, pos,
+        local_expert_count, global_expert_count, fwd_batch_size, world_size
     )
     x = expert_fn(x, fwd_expert_count)
     x = MOEGather.apply(
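The three calls touched here form the scatter → expert → gather pipeline. A communication-free stand-in that mirrors the data flow, under the simplifying assumption that `pos` is a permutation of token indices:

```python
import torch

def moe_forward_sketch(inp, pos, expert_fn, fwd_expert_count):
    x = inp[pos]                        # scatter: group tokens by expert
    x = expert_fn(x, fwd_expert_count)  # run each expert on its slice
    out = torch.empty_like(x)
    out[pos] = x                        # gather: restore original token order
    return out
```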
@@ -165,7 +166,8 @@ class FMoE(nn.Module):
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
         if expert is not None:
-            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
+            self.experts = nn.ModuleList([expert(d_model)
+                                          for _ in range(num_expert)])
             self.experts_fused = False
         else:
             self.experts_fused = True
...
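As the wrapped `nn.ModuleList` line shows, `expert` is a callable that builds one expert module from `d_model`. A hypothetical minimal expert, purely for illustration:

```python
import torch.nn as nn

class TinyExpert(nn.Module):          # stand-in, not a FastMoE class
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        return self.fc(x)

experts = nn.ModuleList([TinyExpert(512) for _ in range(4)])  # num_expert=4
```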
@@ -41,7 +41,7 @@ def _megatron_init_method(self, rng, sigma):
     device = self.weight.device
     dtype = self.weight.dtype
     weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
-    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

     if self.bias is not None:
         # Always initialize bias to zero.
@@ -60,13 +60,13 @@ def _random_init_weight(self, rng):
     device = self.weight.device
     dtype = self.weight.dtype
     weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

     if self.bias is not None:
         fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
         bound = 1 / math.sqrt(fan_in)
         bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
-        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)
+        self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)


 class MegatronMLP(FMoETransformerMLP):
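The `torch.tensor(...)` → `torch.from_numpy(...).to(...)` swap produces the same values while using the idiomatic conversion for NumPy arrays; it also sidesteps pylint's known `not-callable` false positive on `torch.tensor`, which is likely the lint motivation here. `from_numpy` shares memory with the ndarray, and `.to()` copies/converts only as needed:

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
weight = rng.normal(loc=0.0, scale=0.02, size=(8, 8))  # float64 ndarray
# from_numpy shares storage with `weight`; .to() converts dtype/device.
w = torch.from_numpy(weight).to(dtype=torch.float32)
```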
@@ -77,7 +77,8 @@ class MegatronMLP(FMoETransformerMLP):
     def __init__(self, args, group):
         assert (
-            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
+            args.seq_length * args.micro_batch_size
+            % args.tensor_model_parallel_size
             == 0
         ), "Batch size x sequence length should be multiple of mp size"
         if not args.distributed_experts:
...
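A quick numeric check of the reflowed assertion, with assumed values: seq_length=1024, micro_batch_size=4 and tensor_model_parallel_size=8 give 4096 tokens, which split evenly across the model-parallel group:

```python
seq_length, micro_batch_size, mp_size = 1024, 4, 8
assert seq_length * micro_batch_size % mp_size == 0  # 4096 tokens / 8 ranks
```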
@@ -15,8 +15,10 @@ class _Expert(nn.Module):
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True,
+                                rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True,
+                                rank=rank)
         self.activation = activation

     def forward(self, inp, fwd_expert_count):
...
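The names `htoh4`/`h4toh` follow the transformer FFN convention (hidden → 4×hidden → hidden, since `d_hidden` is typically 4×`d_model`); per expert the forward pass is the usual two-layer MLP. A sketch with plain `nn.Linear` as stand-ins for `FMoELinear`:

```python
import torch
import torch.nn as nn

d_model, d_hidden = 512, 2048         # d_hidden = 4 * d_model, typically
htoh4 = nn.Linear(d_model, d_hidden)  # stand-ins for FMoELinear
h4toh = nn.Linear(d_hidden, d_model)

x = torch.randn(10, d_model)
out = h4toh(torch.nn.functional.gelu(htoh4(x)))  # h -> 4h -> h
```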