"src/diffusers/models/modeling_utils.py" did not exist on "96598639c072d7f6dadf173c8c53ddcd4abfc6e5"
Commit 3b82e379 authored by Rick Ho

fix lint

parent 43af1522
@@ -47,7 +47,8 @@ class DistributedGroupedDataParallel(nn.Module):
         else:
             self.comms["world"] = world_group
 
-        def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False):
+        def allreduce_params(no_scale=False,
+                reduce_after=False, fp32_allreduce=False):
             groups = dict()
             for p in self.module.parameters():
                 if not p.requires_grad or p.grad is None:

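Aside: the loop this hunk opens builds `groups` by skipping parameters that have no gradient, so gradients can be bucketed and all-reduced together. A minimal single-process sketch of that bucketing pattern, assuming dtype as the bucket key (the real key in FastMoE may differ):

import torch
from torch import nn

def bucket_grads(module):
    # Mirror the skip logic above; keying buckets by dtype is an
    # illustrative assumption, not FastMoE's actual grouping rule.
    groups = dict()
    for p in module.parameters():
        if not p.requires_grad or p.grad is None:
            continue
        groups.setdefault(p.grad.dtype, []).append(p.grad)
    return groups

m = nn.Linear(4, 4)
m(torch.randn(2, 4)).sum().backward()
print({k: len(v) for k, v in bucket_grads(m).items()})  # {torch.float32: 2}
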
@@ -40,7 +40,8 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         )
     else:
         global_expert_count = local_expert_count
-    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
+    fwd_expert_count = global_expert_count.view(world_size,
+            num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (
         pos,

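The reflowed `view(world_size, num_expert).sum(dim=0)` is worth unpacking: the flattened per-rank counts are reshaped so each row is one rank, then summed column-wise to get how many tokens each local expert will receive. A small standalone illustration (the numbers are made up):

import torch

world_size, num_expert = 2, 2
global_expert_count = torch.tensor([3, 1,   # tokens from rank 0, per expert
                                    2, 4])  # tokens from rank 1, per expert
fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
print(fwd_expert_count)                     # tensor([5, 5])
print(int(fwd_expert_count.sum().item()))   # 10, the fwd_batch_size
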
@@ -23,7 +23,8 @@ class ZeroGate(nn.Module):
         idx = torch.zeros(
             inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
         )
-        score = torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
+        score = torch.ones(inp.shape[0] * self.top_k,
+                device=inp.device) / self.top_k
         return idx, score.reshape(-1, 1, self.top_k)

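As the reflowed lines show, ZeroGate routes every token to expert 0 with a uniform score of 1/top_k per slot. A standalone sketch of that output (shapes inferred from the diff, not copied from the module):

import torch

top_k, batch = 2, 4
inp = torch.randn(batch, 8)                  # (batch, d_model)
idx = torch.zeros(inp.shape[0] * top_k, dtype=torch.int64)
score = torch.ones(inp.shape[0] * top_k) / top_k
print(idx.unique())                          # tensor([0]): expert 0 only
print(score.reshape(-1, 1, top_k).shape)     # torch.Size([4, 1, 2])
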
@@ -114,7 +114,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         fwd_batch_size,
     ) = moe_prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
-        inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
+        inp, pos,
+        local_expert_count, global_expert_count, fwd_batch_size, world_size
     )
     x = expert_fn(x, fwd_expert_count)
     x = MOEGather.apply(

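The hunk is cut off inside the MOEGather call, but the shape of the pipeline is visible: scatter tokens to their experts, run the expert function, gather results back into input order. A pure-PyTorch, single-process sketch of that flow (MOEScatter/MOEGather are custom autograd functions in FastMoE; this stand-in ignores the distributed pieces):

import torch
from torch import nn

def moe_forward_sketch(inp, expert_idx, experts):
    pos = torch.argsort(expert_idx)           # scatter: sort tokens by expert
    x = inp[pos]
    counts = torch.bincount(expert_idx, minlength=len(experts))
    out, start = torch.empty_like(x), 0
    for e, n in enumerate(counts.tolist()):   # expert_fn analogue
        out[start:start + n] = experts[e](x[start:start + n])
        start += n
    y = torch.empty_like(out)
    y[pos] = out                              # gather: restore input order
    return y

experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
inp = torch.randn(5, 8)
expert_idx = torch.tensor([1, 0, 1, 0, 0])
print(moe_forward_sketch(inp, expert_idx, experts).shape)  # torch.Size([5, 8])
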
@@ -165,7 +166,8 @@ class FMoE(nn.Module):
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
         if expert is not None:
-            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
+            self.experts = nn.ModuleList([expert(d_model)
+                for _ in range(num_expert)])
             self.experts_fused = False
         else:
             self.experts_fused = True

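Context for the wrapped comprehension: `expert` is a factory callable that takes d_model and returns one expert module, instantiated once per local expert. A hedged toy example, with nn.Linear standing in for a real expert:

import torch
from torch import nn

d_model, num_expert = 16, 4
expert = lambda d: nn.Linear(d, d)  # illustrative factory, not FastMoE's
experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
print(len(experts), experts[0](torch.randn(2, d_model)).shape)
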
@@ -41,7 +41,7 @@ def _megatron_init_method(self, rng, sigma):
     device = self.weight.device
     dtype = self.weight.dtype
     weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
-    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)
 
     if self.bias is not None:
         # Always initialize bias to zero.

@@ -60,13 +60,13 @@ def _random_init_weight(self, rng):
     device = self.weight.device
     dtype = self.weight.dtype
     weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)
 
     if self.bias is not None:
         fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
         bound = 1 / math.sqrt(fan_in)
         bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
-        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)
+        self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)

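These two hunks carry the one behavioural change in the commit: `torch.tensor(ndarray)` always copies the NumPy buffer, while `torch.from_numpy` wraps it zero-copy and `.to()` only materialises a copy when the dtype or device actually differs. A small check that the values come out the same either way (shapes and scale are made up):

import numpy as np
import torch

rng = np.random.RandomState(0)
weight = rng.normal(loc=0.0, scale=0.02, size=(4, 4))  # float64 ndarray

a = torch.tensor(weight, dtype=torch.float32)          # always copies
b = torch.from_numpy(weight).to(dtype=torch.float32)   # copies only on cast
print(torch.equal(a, b))  # True: identical values
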
@@ -77,7 +77,8 @@ class MegatronMLP(FMoETransformerMLP):
     def __init__(self, args, group):
         assert (
-            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
+            args.seq_length * args.micro_batch_size
+            % args.tensor_model_parallel_size
             == 0
         ), "Batch size x sequence length should be multiple of mp size"
         if not args.distributed_experts:

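The reflowed assert checks that the flattened token count divides evenly across tensor-model-parallel ranks. Worked numbers (illustrative, not from any real config):

seq_length, micro_batch_size, tp_size = 1024, 4, 8
tokens = seq_length * micro_batch_size  # 4096
assert tokens % tp_size == 0            # 4096 % 8 == 0, so the check passes
print(tokens // tp_size)                # 512 tokens per TP rank
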
@@ -15,8 +15,10 @@ class _Expert(nn.Module):
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True,
+                rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True,
+                rank=rank)
         self.activation = activation
 
     def forward(self, inp, fwd_expert_count):

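The two wrapped FMoELinear calls define the classic transformer FFN shape: d_model -> d_hidden (`htoh4`), activation, then d_hidden -> d_model (`h4toh`). A per-expert sketch with plain nn.Linear standing in for FMoELinear, which fuses num_expert such weights into one tensor:

import torch
from torch import nn

d_model, d_hidden = 8, 32
htoh4 = nn.Linear(d_model, d_hidden)
h4toh = nn.Linear(d_hidden, d_model)
activation = torch.relu                   # illustrative choice of activation
x = torch.randn(4, d_model)
print(h4toh(activation(htoh4(x))).shape)  # torch.Size([4, 8])
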