Unverified Commit 34df6069 authored by msbaines, committed by GitHub

[refactor] moe: cleanup code to be more readable (#186)

parent 317c0945
@@ -66,14 +66,6 @@ class MOELayer(Base):
         self.world_size = dist.get_world_size(self.group)
         self.num_local_experts = len(self.experts)
 
-    def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
-        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), input)
-        return _AllToAll.apply(self.group, dispatched_input)
-
-    def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
-        expert_output = _AllToAll.apply(self.group, input)
-        return torch.einsum("sec,ecm->sm", combine_weights, expert_output)
-
     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
         assert len(input[0].shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
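
The removed helpers and the new inlined calls both go through `_AllToAll`, which is defined elsewhere in this file and not shown in the diff. For context, here is a minimal sketch of an autograd-aware all-to-all consistent with how it is used here: it exchanges equal-sized chunks along dim 0 across the ranks of a process group, and its gradient is the reverse exchange. Names and details are assumptions, not the file's actual implementation.

```python
import torch
import torch.distributed as dist
from torch import Tensor
from typing import Any, Tuple

class _AllToAll(torch.autograd.Function):
    """Sketch: swap equal-sized chunks of dim 0 across the ranks of `group`."""

    @staticmethod
    def forward(ctx: Any, group: "dist.ProcessGroup", input: Tensor) -> Tensor:
        ctx.group = group
        input = input.contiguous()
        output = torch.empty_like(input)
        # Requires an initialized process group; rank i sends chunk j to rank j.
        dist.all_to_all_single(output, input, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
        # The gradient of an all-to-all is the all-to-all of the gradients.
        return (None, _AllToAll.apply(ctx.group, grad_output))
```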
@@ -83,8 +75,9 @@ class MOELayer(Base):
         d_model = input[0].shape[2]
         # Reshape into S tokens by dropping sequence dimension.
         reshaped_input = input[0].reshape(-1, d_model)
-        self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
-        dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
+        self.l_aux, combine_weights, dispatch_mask = self.gate(reshaped_input)
+        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), reshaped_input)
+        dispatched_input = _AllToAll.apply(self.group, dispatched_input)
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_input = dispatched_input.reshape(self.world_size, self.num_local_experts, -1, d_model)
         chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
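
Inlined, the dispatch step is now a single einsum followed by the all-to-all. A toy, single-process illustration of what `"sec,sm->ecm"` computes (sizes invented): with a binary dispatch mask over (s)equence tokens, (e)xperts, and (c)apacity slots, each token is scattered into its assigned expert/slot row.

```python
import torch

s, e, c, m = 4, 2, 2, 3  # toy sizes: tokens, experts, capacity slots, model dim
tokens = torch.randn(s, m)
dispatch_mask = torch.zeros(s, e, c)
dispatch_mask[0, 0, 0] = 1  # token 0 -> expert 0, slot 0
dispatch_mask[1, 1, 0] = 1  # token 1 -> expert 1, slot 0
dispatch_mask[2, 0, 1] = 1  # token 2 -> expert 0, slot 1
dispatch_mask[3, 1, 1] = 1  # token 3 -> expert 1, slot 1

# "sec,sm->ecm": slot (e, c) receives the mask-weighted sum over tokens,
# i.e. exactly the one token whose mask bit is set for that slot.
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
assert torch.allclose(dispatched[0, 0], tokens[0])
assert torch.allclose(dispatched[1, 1], tokens[3])
```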
@@ -92,7 +85,8 @@ class MOELayer(Base):
         for chunk, expert in zip(chunks, self.experts):
             expert_outputs += [expert(chunk)]
         expert_output = torch.cat(expert_outputs, dim=1)
+        expert_output = _AllToAll.apply(self.group, expert_output)
         # Re-shape back: gecm -> ecm
         expert_output = expert_output.reshape(self.world_size * self.num_local_experts, -1, d_model)
-        combined_output = self.all_to_all_combine(combine_weights, expert_output)
+        combined_output = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
         return combined_output.reshape(input[0].shape)
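
The combine step is the mirror image: `"sec,ecm->sm"` rebuilds each token as a gate-weighted mixture of the expert outputs it was routed to. A self-contained toy check (sizes and weights invented):

```python
import torch

s, e, c, m = 2, 2, 1, 3  # toy sizes: tokens, experts, capacity slots, model dim
expert_output = torch.randn(e, c, m)
combine_weights = torch.tensor([[[1.0], [0.0]],   # token 0: all from expert 0
                                [[0.3], [0.7]]])  # token 1: 0.3/0.7 mixture

# "sec,ecm->sm": each token is a weighted sum over its expert/slot outputs.
combined = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
assert torch.allclose(combined[0], expert_output[0, 0])
assert torch.allclose(combined[1], 0.3 * expert_output[0, 0] + 0.7 * expert_output[1, 0])
```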
@@ -112,8 +112,8 @@ class Top2Gate(torch.nn.Module):
 
     def __init__(self, model_dim: int, num_experts: int,) -> None:
         super().__init__()
-        self.wg = torch.nn.Linear(num_experts, model_dim, bias=False)
+        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
 
     def forward(self, input: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
-        logits = torch.einsum("sm,me -> se", input, self.wg.weight)
+        logits = self.wg(input)
         return top2gating(logits)
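
The `Top2Gate` change replaces a hand-rolled einsum with a plain `nn.Linear` call. With the new `Linear(model_dim, num_experts)` argument order the weight has shape `(num_experts, model_dim)`, so the direct call computes the same logits as an einsum over the shared model dimension. A quick sanity check (toy sizes assumed):

```python
import torch

model_dim, num_experts, s = 4, 3, 5  # toy sizes
wg = torch.nn.Linear(model_dim, num_experts, bias=False)  # weight shape: (e, m)
x = torch.randn(s, model_dim)

# wg(x) == x @ wg.weight.T, i.e. the same contraction over the model dim.
assert torch.allclose(wg(x), torch.einsum("sm,em->se", x, wg.weight))
```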