Commit 89176e34, authored by msbaines and committed by GitHub (signature unverified)

[refactor] moe: remove G dimension (#183)

parent 5d4f50fb
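
For orientation (a sketch, not part of the commit): the refactor drops the leading (g)roup dimension from MOELayer's input, whose size previously had to equal the number of local experts. A minimal illustration of the contract change, with made-up sizes:

```python
import torch

# Hypothetical sizes, for illustration only.
num_local_experts, seq, tokens, model_dim = 2, 4, 16, 8

# Before: 4-D input (g, s, t, m), one group per local expert.
old_input = torch.randn(num_local_experts, seq, tokens, model_dim)

# After: 3-D input (s, t, m); the leading dimension only needs to be
# divisible by the number of local experts.
new_input = torch.randn(seq * num_local_experts, tokens, model_dim)
assert new_input.shape[0] % num_local_experts == 0
```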
@@ -66,29 +66,28 @@ class MOELayer(Base):
             p.expert = True  # type: ignore
 
     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
-        dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
+        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), input)
         return _AllToAll.apply(self.group, dispatched_input)
 
     def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
         expert_output = _AllToAll.apply(self.group, input)
-        return torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
+        return torch.einsum("sec,ecm->sm", combine_weights, expert_output)
 
     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
-        assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
-        assert input[0].shape[0] == len(self.experts), "group dimension size must be equal to number of local experts"
+        assert len(input[0].shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
+        assert input[0].shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
 
         # Implement Algorithm 2 from GShard paper.
         shape = input[0].shape
-        # Reshape into S tokens per group.
-        reshaped_input = input[0].reshape(shape[0], -1, shape[3])
+        # Reshape into S tokens by dropping sequence dimension.
+        reshaped_input = input[0].reshape(-1, shape[2])
         self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
         dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
-        assert dispatched_input.shape[1] == len(self.experts)
-        chunks = dispatched_input.chunk(len(self.experts), dim=1)
+        chunks = dispatched_input.chunk(len(self.experts), dim=0)
         expert_outputs = []
         for chunk, expert in zip(chunks, self.experts):
             expert_outputs += [expert(chunk)]
-        expert_output = torch.cat(expert_outputs, dim=1)
+        expert_output = torch.cat(expert_outputs, dim=0)
         combined_output = self.all_to_all_combine(combine_weights, expert_output)
         return combined_output.reshape(shape)
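
The two einsums above are the heart of the dispatch/combine path. A standalone shape check of the new (G-free) contractions, with the all-to-all omitted and sizes made up:

```python
import torch

s, e, c, m = 16, 4, 8, 8               # tokens, experts, capacity, model dim (made up)
dispatch_mask = torch.zeros(s, e, c)   # routes token s to (expert e, capacity slot c)
tokens = torch.randn(s, m)

# Dispatch: scatter tokens into per-expert capacity buffers.
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
assert dispatched.shape == (e, c, m)

# Combine: weighted sum of expert outputs back into per-token features.
combined = torch.einsum("sec,ecm->sm", dispatch_mask, dispatched)
assert combined.shape == (s, m)
```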
@@ -28,36 +28,36 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
 def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
-    gates = F.softmax(logits, dim=2)
-    # gates has shape of GSE
-    num_tokens = gates.shape[1]
-    num_experts = gates.shape[2]
+    gates = F.softmax(logits, dim=1)
+    # gates has shape of SE
+    num_tokens = gates.shape[0]
+    num_experts = gates.shape[1]
     # capacity = 2S/E
     capacity = 2 * num_tokens // num_experts
     assert num_tokens % num_experts == 0
     # Create a mask for 1st's expert per token
-    indices1_gs = torch.argmax(gates, dim=2)
-    mask1 = F.one_hot(indices1_gs, num_classes=num_experts)
+    indices1_s = torch.argmax(gates, dim=1)
+    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
     # Create a mask for 2nd's expert per token using Gumbel-max trick
     # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
     logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
     # Replace top-expert with min value
     logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
-    indices2_gs = torch.argmax(logits_except1, dim=2)
-    mask2 = F.one_hot(indices2_gs, num_classes=num_experts)
+    indices2_s = torch.argmax(logits_except1, dim=1)
+    mask2 = F.one_hot(indices2_s, num_classes=num_experts)
     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=1) - 1
-    locations2 = torch.cumsum(mask2, dim=1) - 1
+    locations1 = torch.cumsum(mask1, dim=0) - 1
+    locations2 = torch.cumsum(mask2, dim=0) - 1
     # Update 2nd's location by accounting for locations of 1st
-    locations2 += torch.sum(mask1, dim=1, keepdim=True)
+    locations2 += torch.sum(mask1, dim=0, keepdim=True)
     # Compute l_aux
-    me = torch.mean(gates, dim=1)
-    ce = torch.mean(mask1.float(), dim=1)
+    me = torch.mean(gates, dim=0)
+    ce = torch.mean(mask1.float(), dim=0)
     l_aux = torch.mean(me * ce)
     # Remove locations outside capacity from mask
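
For readers following the math: l_aux is the GShard load-balancing loss, the mean over experts of (mean gate probability) times (fraction of tokens whose top-1 pick is that expert). A small worked example with invented numbers:

```python
import torch
import torch.nn.functional as F

gates = torch.tensor([[0.7, 0.3],
                      [0.6, 0.4],
                      [0.2, 0.8],
                      [0.1, 0.9]])          # s=4 tokens, e=2 experts
mask1 = F.one_hot(gates.argmax(dim=1), num_classes=2)
me = torch.mean(gates, dim=0)               # [0.4, 0.6]
ce = torch.mean(mask1.float(), dim=0)       # [0.5, 0.5]
l_aux = torch.mean(me * ce)                 # (0.4*0.5 + 0.6*0.5) / 2 = 0.25
```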
@@ -65,28 +65,28 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     mask2 *= torch.lt(locations2, capacity)
     # Store the capacity location for each token
-    locations1_gs = torch.sum(locations1 * mask1, dim=2)
-    locations2_gs = torch.sum(locations2 * mask2, dim=2)
+    locations1_s = torch.sum(locations1 * mask1, dim=1)
+    locations2_s = torch.sum(locations2 * mask2, dim=1)
     # Normalize gate probabilities
     mask1_float = mask1.float()
     mask2_float = mask2.float()
-    gates1_gs = torch.einsum("gse,gse->gs", gates, mask1_float)
-    gates2_gs = torch.einsum("gse,gse->gs", gates, mask2_float)
-    denom_gs = gates1_gs + gates2_gs
+    gates1_s = torch.einsum("se,se->s", gates, mask1_float)
+    gates2_s = torch.einsum("se,se->s", gates, mask2_float)
+    denom_s = gates1_s + gates2_s
     # Avoid divide-by-zero
-    denom_gs = torch.where(denom_gs > 0, denom_gs, torch.ones_like(denom_gs))
-    gates1_gs /= denom_gs
-    gates2_gs /= denom_gs
+    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
+    gates1_s /= denom_s
+    gates2_s /= denom_s
     # Calculate combine_weights and dispatch_mask
-    gates1 = torch.einsum("gs,gse->gse", gates1_gs, mask1_float)
-    gates2 = torch.einsum("gs,gse->gse", gates2_gs, mask2_float)
-    locations1_gsc = F.one_hot(locations1_gs, num_classes=capacity)
-    locations2_gsc = F.one_hot(locations2_gs, num_classes=capacity)
-    combine1_gsec = torch.einsum("gse,gsc->gsec", gates1, locations1_gsc)
-    combine2_gsec = torch.einsum("gse,gsc->gsec", gates2, locations2_gsc)
-    combine_weights = combine1_gsec + combine2_gsec
+    gates1 = torch.einsum("s,se->se", gates1_s, mask1_float)
+    gates2 = torch.einsum("s,se->se", gates2_s, mask2_float)
+    locations1_sc = F.one_hot(locations1_s, num_classes=capacity)
+    locations2_sc = F.one_hot(locations2_s, num_classes=capacity)
+    combine1_sec = torch.einsum("se,sc->sec", gates1, locations1_sc)
+    combine2_sec = torch.einsum("se,sc->sec", gates2, locations2_sc)
+    combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
     return l_aux, combine_weights, dispatch_mask
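
The Gumbel-max trick used above for the second expert is worth a note: adding i.i.d. Gumbel noise to logits and taking the argmax draws an exact sample from softmax(logits), which is what lets the gate sample over the remaining experts without an explicit multinomial. A quick empirical check (illustration only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([1.0, 0.5, -0.2, 0.0])
gumbel = torch.distributions.gumbel.Gumbel(0.0, 1.0)

draws = torch.stack([(logits + gumbel.sample(logits.shape)).argmax() for _ in range(20000)])
freq = torch.bincount(draws, minlength=4).float() / 20000
print(freq)                       # ~[0.44, 0.27, 0.13, 0.16], up to sampling noise
print(F.softmax(logits, dim=0))   # ...which these probabilities should match
```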
@@ -115,5 +115,5 @@ class Top2Gate(torch.nn.Module):
         self.wg = torch.nn.Linear(num_experts, model_dim, bias=False)
 
     def forward(self, input: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
-        logits = torch.einsum("gsm,me -> gse", input, self.wg.weight)
+        logits = torch.einsum("sm,me -> se", input, self.wg.weight)
         return top2gating(logits)
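
One subtlety (commentary, not from the diff): the einsum labels line up because nn.Linear(in_features, out_features) stores its weight as (out_features, in_features), so the Linear(num_experts, model_dim) above carries a (model_dim, num_experts) = (m, e) weight:

```python
import torch

model_dim, num_experts, s = 4, 6, 12    # illustrative sizes
wg = torch.nn.Linear(num_experts, model_dim, bias=False)
assert wg.weight.shape == (model_dim, num_experts)   # (m, e)

x = torch.randn(s, model_dim)
logits = torch.einsum("sm,me->se", x, wg.weight)
assert logits.shape == (s, num_experts)
```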
@@ -37,6 +37,10 @@ from . import version
 class dtype:
     is_floating_point: builtins.bool
 
+class finfo:
+    def __init__(self, dtype: dtype): ...
+    eps: float
+
 class layout: ...
 strided : layout = ...
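
The stub grows a finfo class because the new divide-by-zero guard in top2gating calls torch.finfo(dtype).eps. What that guard does, in isolation (sketch, not part of the commit):

```python
import torch

denom = torch.tensor([0.0, 0.5])
eps = torch.finfo(denom.dtype).eps    # smallest float32 increment, ~1.19e-07
denom = torch.clamp(denom, min=eps)   # zero denominators become eps rather than 1.0
```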
@@ -23,18 +23,17 @@ else:
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29501"
 
 if "OMPI_COMM_WORLD_SIZE" in os.environ:
-    pass # dist.init_process_group(backend=dist.Backend.MPI)
+    dist.init_process_group(backend=dist.Backend.MPI)
 
 def setup_module(module):
     if "OMPI_COMM_WORLD_SIZE" not in os.environ:
         dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
-    else:
-        dist.init_process_group(backend=dist.Backend.MPI)
 
 def teardown_module(module):
-    torch.distributed.destroy_process_group()
+    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
+        torch.distributed.destroy_process_group()
 
 @pytest.mark.parametrize("device", devices)
@@ -62,7 +61,7 @@ def test_expert_params(device):
 def test_forward(device):
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use identity matrix
@@ -81,7 +80,7 @@ def test_forward_multi(device):
     num_local_experts = 4
     model_dim = 4
     num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
-    input = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
+    input = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     experts = []
     for i in range(num_local_experts):
@@ -106,12 +105,12 @@ class RoundRobinGate(torch.nn.Module):
         self.num_experts = num_experts
 
     def forward(self, input):
-        g, s, _ = input.shape
+        s = input.shape[0]
         assert s % self.num_experts == 0
         capacity = 2 * s // self.num_experts
-        output = torch.zeros(g, s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
+        output = torch.zeros(s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
         for i in range(s):
-            output[:, i, i % self.num_experts, i // self.num_experts] = 1.0
+            output[i, i % self.num_experts, i // self.num_experts] = 1.0
         return 0.0, output, output.bool()
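
RoundRobinGate doubles as a compact statement of the gate protocol MOELayer relies on: forward returns (l_aux, combine_weights, dispatch_mask), the latter two shaped (s, e, c). Tracing the deterministic routing for tiny, made-up sizes:

```python
import torch

s, e = 4, 2
capacity = 2 * s // e                  # = 4
output = torch.zeros(s, e, capacity)
for i in range(s):
    output[i, i % e, i // e] = 1.0
# token 0 -> (expert 0, slot 0)    token 1 -> (expert 1, slot 0)
# token 2 -> (expert 0, slot 1)    token 3 -> (expert 1, slot 1)
```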
@@ -120,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
 def test_forward_routing(device):
     model_dim = 8
     num_experts = dist.get_world_size()
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = RoundRobinGate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use scaling matrix (each rank has a different scale)
@@ -130,10 +129,10 @@ def test_forward_routing(device):
     output = moe(input)
     assert output.shape == input.shape
     # Verify that each token was sent to the correct expert by checking its scale.
-    t = input.shape[2]
+    t = input.shape[1]
     for i in range(t):
         expert = i % num_experts
-        assert torch.allclose(input[:, :, i] * (expert + 1), output[:, :, i])
+        assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
 
 @pytest.mark.mpi
@@ -142,7 +141,7 @@ def test_backward(device):
     loss = torch.nn.MSELoss()
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use identity matrix
@@ -23,21 +23,21 @@ def test_create_cuda():
 def do_test_forward(device):
     torch.manual_seed(3)
-    input = torch.randn(3, 12, 4).to(device)
+    input = torch.randn(12, 4).to(device)
     gate = Top2Gate(4, 6).to(device)
     capacity = 2 * 12 // 6
     l_aux, combine_weights, dispatch_mask = gate(input)
     assert pytest.approx(l_aux.item(), 0.0283)
-    assert combine_weights.shape == (3, 12, 6, 4)
-    assert dispatch_mask.shape == (3, 12, 6, 4)
+    assert combine_weights.shape == (12, 6, 4)
+    assert dispatch_mask.shape == (12, 6, 4)
     assert torch.equal(combine_weights.bool(), dispatch_mask)
-    assert torch.all(torch.sum(dispatch_mask, axis=(1, 3)) <= capacity)
+    assert torch.all(torch.sum(dispatch_mask, axis=(0, 2)) <= capacity)
     assert torch.all(combine_weights >= 0.0)
     assert torch.all(combine_weights <= 1.0)
     weights_sum = torch.sum(combine_weights).item()
     assert round(weights_sum) == pytest.approx(weights_sum)
-    # For this random seed, we get 36 slots filled.
-    assert weights_sum == pytest.approx(36.0)
+    # For this random seed, we get 12 slots filled.
+    assert weights_sum == pytest.approx(12.0)
 
 def test_forward_cpu():
@@ -53,15 +53,15 @@ def test_forward_cuda():
 def test_top1s():
     num_tokens = 8
     num_experts = 4
-    logits = torch.randn(1, num_tokens, num_experts)
+    logits = torch.randn(num_tokens, num_experts)
     l_aux, _, dispatch_mask = top2gating(logits)
-    top1s = torch.argmax(logits, dim=2)
+    top1s = torch.argmax(logits, dim=1)
     capacity = 2 * num_tokens // num_experts
     ce = [0] * num_experts
     locations = [0] * num_tokens
-    for i, s in enumerate(top1s[0]):
+    for i, s in enumerate(top1s):
         e = s.item()
         loc = ce[e]
         ce[e] = loc + 1
         if ce[e] < capacity:
-            assert dispatch_mask[0][i][e][loc]
+            assert dispatch_mask[i][e][loc]