Unverified Commit 89176e34 authored by msbaines, committed by GitHub

[refactor] moe: remove G dimension (#183)

parent 5d4f50fb
@@ -66,29 +66,28 @@ class MOELayer(Base):
             p.expert = True  # type: ignore

     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
-        dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
+        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), input)
         return _AllToAll.apply(self.group, dispatched_input)

     def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
         expert_output = _AllToAll.apply(self.group, input)
-        return torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
+        return torch.einsum("sec,ecm->sm", combine_weights, expert_output)

     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
-        assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
-        assert input[0].shape[0] == len(self.experts), "group dimension size must be equal to number of local experts"
+        assert len(input[0].shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
+        assert input[0].shape[0] % len(self.experts) == 0, "num tokens must be a multiple of the number of local experts"

         # Implement Algorithm 2 from GShard paper.
         shape = input[0].shape
-        # Reshape into S tokens per group.
-        reshaped_input = input[0].reshape(shape[0], -1, shape[3])
+        # Reshape into S tokens by folding the sequence dimension.
+        reshaped_input = input[0].reshape(-1, shape[2])
         self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
         dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
-        assert dispatched_input.shape[1] == len(self.experts)
-        chunks = dispatched_input.chunk(len(self.experts), dim=1)
+        chunks = dispatched_input.chunk(len(self.experts), dim=0)
         expert_outputs = []
         for chunk, expert in zip(chunks, self.experts):
             expert_outputs += [expert(chunk)]
-        expert_output = torch.cat(expert_outputs, dim=1)
+        expert_output = torch.cat(expert_outputs, dim=0)
         combined_output = self.all_to_all_combine(combine_weights, expert_output)
         return combined_output.reshape(shape)
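With the (g)roup dimension gone, dispatch and combine reduce to single contractions over the flat token axis. A minimal single-process sketch of what the two einsums compute, using a made-up round-robin mask and invented sizes (no all-to-all involved):

    import torch

    s, e, c, m = 8, 2, 8, 4  # tokens, experts, capacity, model dim (sizes invented)
    tokens = torch.randn(s, m)
    dispatch_mask = torch.zeros(s, e, c)
    for i in range(s):
        dispatch_mask[i, i % e, i // e] = 1.0  # round-robin: one (expert, slot) per token

    # dispatch: (s,e,c) x (s,m) -> (e,c,m); expert-major layout feeds chunk(dim=0)
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
    # combine: the inverse contraction gathers each token back from its slot
    combined = torch.einsum("sec,ecm->sm", dispatch_mask, dispatched)
    assert torch.allclose(combined, tokens)  # exact because the mask is one-hot

Because the mask here is one-hot per token and slots never collide, combine exactly inverts dispatch; with real top-2 gates the combine step instead mixes the two chosen experts' outputs by their normalized weights.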
@@ -28,36 +28,36 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:

 def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
-    gates = F.softmax(logits, dim=2)
-    # gates has shape of GSE
-    num_tokens = gates.shape[1]
-    num_experts = gates.shape[2]
+    gates = F.softmax(logits, dim=1)
+    # gates has shape of SE
+    num_tokens = gates.shape[0]
+    num_experts = gates.shape[1]
     # capacity = 2S/E
     capacity = 2 * num_tokens // num_experts
     assert num_tokens % num_experts == 0
     # Create a mask for the 1st expert per token
-    indices1_gs = torch.argmax(gates, dim=2)
-    mask1 = F.one_hot(indices1_gs, num_classes=num_experts)
+    indices1_s = torch.argmax(gates, dim=1)
+    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
     # Create a mask for the 2nd expert per token using the Gumbel-max trick
     # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
     logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
     # Mask out the top expert so it cannot be picked again
     logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
-    indices2_gs = torch.argmax(logits_except1, dim=2)
-    mask2 = F.one_hot(indices2_gs, num_classes=num_experts)
+    indices2_s = torch.argmax(logits_except1, dim=1)
+    mask2 = F.one_hot(indices2_s, num_classes=num_experts)
     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=1) - 1
-    locations2 = torch.cumsum(mask2, dim=1) - 1
+    locations1 = torch.cumsum(mask1, dim=0) - 1
+    locations2 = torch.cumsum(mask2, dim=0) - 1
     # Update 2nd's location by accounting for locations of 1st
-    locations2 += torch.sum(mask1, dim=1, keepdim=True)
+    locations2 += torch.sum(mask1, dim=0, keepdim=True)
     # Compute l_aux
-    me = torch.mean(gates, dim=1)
-    ce = torch.mean(mask1.float(), dim=1)
+    me = torch.mean(gates, dim=0)
+    ce = torch.mean(mask1.float(), dim=0)
     l_aux = torch.mean(me * ce)
     # Remove locations outside capacity from mask
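The bookkeeping above is easiest to see with tiny numbers. A toy walk-through, assuming invented logits for four tokens and two experts:

    import torch
    import torch.nn.functional as F

    # Four tokens, two experts; logits invented for illustration.
    gates = F.softmax(torch.tensor([[2.0, 0.1], [1.5, 0.2], [0.1, 3.0], [0.3, 2.5]]), dim=1)
    num_tokens, num_experts = gates.shape     # S=4, E=2
    capacity = 2 * num_tokens // num_experts  # 2S/E = 4

    indices1_s = torch.argmax(gates, dim=1)   # tensor([0, 0, 1, 1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
    # cumsum along the token axis hands out buffer slots in token order:
    locations1 = torch.cumsum(mask1, dim=0) - 1  # token 1 is expert 0's second pick -> slot 1

    me = torch.mean(gates, dim=0)             # mean gate probability per expert
    ce = torch.mean(mask1.float(), dim=0)     # fraction of tokens routed to each expert
    l_aux = torch.mean(me * ce)               # smallest when routing is balanced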
@@ -65,28 +65,28 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     mask2 *= torch.lt(locations2, capacity)
     # Store the capacity location for each token
-    locations1_gs = torch.sum(locations1 * mask1, dim=2)
-    locations2_gs = torch.sum(locations2 * mask2, dim=2)
+    locations1_s = torch.sum(locations1 * mask1, dim=1)
+    locations2_s = torch.sum(locations2 * mask2, dim=1)
     # Normalize gate probabilities
     mask1_float = mask1.float()
     mask2_float = mask2.float()
-    gates1_gs = torch.einsum("gse,gse->gs", gates, mask1_float)
-    gates2_gs = torch.einsum("gse,gse->gs", gates, mask2_float)
-    denom_gs = gates1_gs + gates2_gs
+    gates1_s = torch.einsum("se,se->s", gates, mask1_float)
+    gates2_s = torch.einsum("se,se->s", gates, mask2_float)
+    denom_s = gates1_s + gates2_s
     # Avoid divide-by-zero
-    denom_gs = torch.where(denom_gs > 0, denom_gs, torch.ones_like(denom_gs))
-    gates1_gs /= denom_gs
-    gates2_gs /= denom_gs
+    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
+    gates1_s /= denom_s
+    gates2_s /= denom_s
     # Calculate combine_weights and dispatch_mask
-    gates1 = torch.einsum("gs,gse->gse", gates1_gs, mask1_float)
-    gates2 = torch.einsum("gs,gse->gse", gates2_gs, mask2_float)
-    locations1_gsc = F.one_hot(locations1_gs, num_classes=capacity)
-    locations2_gsc = F.one_hot(locations2_gs, num_classes=capacity)
-    combine1_gsec = torch.einsum("gse,gsc->gsec", gates1, locations1_gsc)
-    combine2_gsec = torch.einsum("gse,gsc->gsec", gates2, locations2_gsc)
-    combine_weights = combine1_gsec + combine2_gsec
+    gates1 = torch.einsum("s,se->se", gates1_s, mask1_float)
+    gates2 = torch.einsum("s,se->se", gates2_s, mask2_float)
+    locations1_sc = F.one_hot(locations1_s, num_classes=capacity)
+    locations2_sc = F.one_hot(locations2_s, num_classes=capacity)
+    combine1_sec = torch.einsum("se,sc->sec", gates1, locations1_sc)
+    combine2_sec = torch.einsum("se,sc->sec", gates2, locations2_sc)
+    combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
     return l_aux, combine_weights, dispatch_mask
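Since each token contributes its normalized weight at exactly one (expert, capacity-slot) coordinate per chosen expert, combine_weights is sparse and its boolean view doubles as the dispatch mask. A small sketch with invented per-token values:

    import torch
    import torch.nn.functional as F

    # Invented values: 3 tokens, 2 experts, capacity 3.
    gates1_s = torch.tensor([0.7, 0.6, 0.9])        # normalized top-1 weight per token
    mask1_f = torch.tensor([[1., 0.], [0., 1.], [1., 0.]])
    locations1_s = torch.tensor([0, 0, 1])          # slot of each token in its expert's buffer

    gates1 = torch.einsum("s,se->se", gates1_s, mask1_f)    # weight placed on the chosen expert
    locations1_sc = F.one_hot(locations1_s, num_classes=3).float()
    combine1_sec = torch.einsum("se,sc->sec", gates1, locations1_sc)

    print(combine1_sec.shape)         # torch.Size([3, 2, 3])
    print(combine1_sec.bool().sum())  # 3 nonzero entries: one (e, c) slot per token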
@@ -115,5 +115,5 @@ class Top2Gate(torch.nn.Module):
         self.wg = torch.nn.Linear(num_experts, model_dim, bias=False)

     def forward(self, input: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
-        logits = torch.einsum("gsm,me -> gse", input, self.wg.weight)
+        logits = torch.einsum("sm,me -> se", input, self.wg.weight)
         return top2gating(logits)
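The gate projection is an ordinary matrix multiply; the einsum spelling only documents the axis names. A quick equivalence check (sizes are illustrative):

    import torch

    s, m, e = 12, 4, 6                       # tokens, model dim, experts (illustrative)
    wg = torch.nn.Linear(e, m, bias=False)   # weight has shape (m, e), matching "me" above
    x = torch.randn(s, m)

    logits_einsum = torch.einsum("sm,me->se", x, wg.weight)
    logits_matmul = x @ wg.weight            # same contraction, plainly spelled
    assert torch.allclose(logits_einsum, logits_matmul)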
@@ -37,6 +37,10 @@ from . import version
 class dtype:
     is_floating_point: builtins.bool

+class finfo:
+    def __init__(self, dtype: dtype): ...
+    eps: float
+
 class layout: ...

 strided : layout = ...
...
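The finfo stub exists to type-check the new divide-by-zero guard in top2gating, which clamps the denominator to the dtype's machine epsilon rather than branching with torch.where. In isolation:

    import torch

    denom = torch.tensor([0.0, 0.8, 1.3])
    eps = torch.finfo(denom.dtype).eps            # ~1.19e-07 for float32
    safe = torch.clamp(denom, min=eps)            # zero entries become eps, others untouched
    ratio = torch.tensor([0.0, 0.4, 0.9]) / safe  # no inf or nan from the empty slot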
@@ -23,17 +23,16 @@ else:
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29501"

 if "OMPI_COMM_WORLD_SIZE" in os.environ:
-    pass  # dist.init_process_group(backend=dist.Backend.MPI)
+    dist.init_process_group(backend=dist.Backend.MPI)

 def setup_module(module):
     if "OMPI_COMM_WORLD_SIZE" not in os.environ:
         dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
-    else:
-        dist.init_process_group(backend=dist.Backend.MPI)

 def teardown_module(module):
-    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
-        torch.distributed.destroy_process_group()
+    torch.distributed.destroy_process_group()
@@ -62,7 +61,7 @@ def test_expert_params(device):
 def test_forward(device):
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use identity matrix
@@ -81,7 +80,7 @@ def test_forward_multi(device):
     num_local_experts = 4
     model_dim = 4
     num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
-    input = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
+    input = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     experts = []
     for i in range(num_local_experts):
@@ -106,12 +105,12 @@ class RoundRobinGate(torch.nn.Module):
         self.num_experts = num_experts

     def forward(self, input):
-        g, s, _ = input.shape
+        s = input.shape[0]
         assert s % self.num_experts == 0
         capacity = 2 * s // self.num_experts
-        output = torch.zeros(g, s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
+        output = torch.zeros(s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
         for i in range(s):
-            output[:, i, i % self.num_experts, i // self.num_experts] = 1.0
+            output[i, i % self.num_experts, i // self.num_experts] = 1.0
         return 0.0, output, output.bool()
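With the group axis removed, this deterministic test gate returns an (s, e, c) mask that is easy to inspect by hand. For example, reusing the class above (constructor arguments follow the test's positional call, model_dim then num_experts):

    import torch

    gate = RoundRobinGate(8, 2)   # (model_dim, num_experts), per the test call below
    _, weights, mask = gate(torch.randn(4, 8))
    print(mask.shape)             # torch.Size([4, 2, 4]): capacity = 2*4//2 = 4
    print(mask[0].nonzero())      # tensor([[0, 0]]): token 0 -> expert 0, slot 0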
@@ -120,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
 def test_forward_routing(device):
     model_dim = 8
     num_experts = dist.get_world_size()
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = RoundRobinGate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use scaling matrix (each rank has a different scale)
@@ -130,10 +129,10 @@ def test_forward_routing(device):
     output = moe(input)
     assert output.shape == input.shape
     # Verify that each token was sent to the correct expert by checking its scale.
-    t = input.shape[2]
+    t = input.shape[1]
     for i in range(t):
         expert = i % num_experts
-        assert torch.allclose(input[:, :, i] * (expert + 1), output[:, :, i])
+        assert torch.allclose(input[:, i] * (expert + 1), output[:, i])

 @pytest.mark.mpi
@@ -142,7 +141,7 @@ def test_backward(device):
     loss = torch.nn.MSELoss()
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use identity matrix
...
@@ -23,21 +23,21 @@ def test_create_cuda():
 def do_test_forward(device):
     torch.manual_seed(3)
-    input = torch.randn(3, 12, 4).to(device)
+    input = torch.randn(12, 4).to(device)
     gate = Top2Gate(4, 6).to(device)
     capacity = 2 * 12 // 6
     l_aux, combine_weights, dispatch_mask = gate(input)
     assert pytest.approx(l_aux.item(), 0.0283)
-    assert combine_weights.shape == (3, 12, 6, 4)
-    assert dispatch_mask.shape == (3, 12, 6, 4)
+    assert combine_weights.shape == (12, 6, 4)
+    assert dispatch_mask.shape == (12, 6, 4)
     assert torch.equal(combine_weights.bool(), dispatch_mask)
-    assert torch.all(torch.sum(dispatch_mask, axis=(1, 3)) <= capacity)
+    assert torch.all(torch.sum(dispatch_mask, axis=(0, 2)) <= capacity)
     assert torch.all(combine_weights >= 0.0)
     assert torch.all(combine_weights <= 1.0)
     weights_sum = torch.sum(combine_weights).item()
     assert round(weights_sum) == pytest.approx(weights_sum)
-    # For this random seed, we get 36 slots filled.
-    assert weights_sum == pytest.approx(36.0)
+    # For this random seed, we get 12 slots filled.
+    assert weights_sum == pytest.approx(12.0)

 def test_forward_cpu():
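Why the expected sum dropped from 36.0 to 12.0: there are now 12 tokens rather than 3 x 12, and each token kept within capacity contributes combine weights summing to roughly 1 after the normalization in top2gating. A sanity check under the same seed (import path assumed):

    import torch
    from fairscale.nn import Top2Gate  # import path assumed

    torch.manual_seed(3)
    gate = Top2Gate(4, 6)
    _, combine_weights, _ = gate(torch.randn(12, 4))
    per_token = combine_weights.sum(dim=(1, 2))  # ~1.0 for each token kept within capacity
    print(per_token.sum())                       # ~12.0, matching the assertion above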
@@ -53,15 +53,15 @@ def test_forward_cuda():
 def test_top1s():
     num_tokens = 8
     num_experts = 4
-    logits = torch.randn(1, num_tokens, num_experts)
+    logits = torch.randn(num_tokens, num_experts)
     l_aux, _, dispatch_mask = top2gating(logits)
-    top1s = torch.argmax(logits, dim=2)
+    top1s = torch.argmax(logits, dim=1)
     capacity = 2 * num_tokens // num_experts
     ce = [0] * num_experts
     locations = [0] * num_tokens
-    for i, s in enumerate(top1s[0]):
+    for i, s in enumerate(top1s):
         e = s.item()
         loc = ce[e]
         ce[e] = loc + 1
         if ce[e] < capacity:
-            assert dispatch_mask[0][i][e][loc]
+            assert dispatch_mask[i][e][loc]