Commit a0bea425 authored by rprenger's avatar rprenger
Browse files

Removing unnecessary permutes and scatter/gather

parent be136206
...@@ -94,7 +94,6 @@ class ParallelMLP(MegatronModule):
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
...@@ -106,7 +105,7 @@ class SwitchMLP(MegatronModule):
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))
def forward(self, hidden_states):
    """Top-1 (switch) routing over the MLP experts.

    Each token is sent to exactly one expert (the argmax of the router's
    softmax), and that expert's output is scaled by the routing probability.

    Args:
        hidden_states: [b, s, h] tensor (batch, sequence, hidden).

    Returns:
        Tuple (output_total, output_bias_total), both [b, s, h]: the routed
        expert outputs and their (expanded) biases, each weighted by the
        top-1 router probability.
    """
    b = hidden_states.size(0)
    # NOTE(review): the s/h reads below sit under a collapsed diff hunk in the
    # scrape; they are reconstructed from the view(b, s, h) calls at the end.
    s = hidden_states.size(1)
    h = hidden_states.size(2)
    route = self.router(hidden_states)
    route = torch.nn.functional.softmax(route, dim=2)
    # Top-1 routing: each token goes to a single expert.
    max_prob, max_ind = torch.max(route, dim=2)
    max_prob = torch.unsqueeze(max_prob, 2)  # [b s 1]

    # TODO (rprenger) TODO this could be made easier to read
    # Converting [b, s, h] to [b*s, h].
    # Each vector could be routed differently
    hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s h]
    max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s 1]
    max_ind = max_ind.view(-1)  # [b*s]

    output_total = torch.empty_like(hidden_states)
    output_bias_total = torch.empty_like(hidden_states)
    # TODO (rprenger) This does each expert in serial, but it could be parallelized
    for expert_num, expert in enumerate(self.experts):
        # Rows of the flattened input that were routed to this expert;
        # nonzero() yields [k, 1] indices, so advanced indexing keeps a
        # singleton dim that expert() passes through unchanged.
        local_indices = (max_ind == expert_num).nonzero()
        hidden = hidden_states[local_indices, :]
        output, output_bias = expert(hidden)
        output_bias = output_bias.expand_as(output)
        output_total[local_indices, :] = output
        output_bias_total[local_indices, :] = output_bias

    # Weight every token's expert output by its routing probability.
    output_total = output_total * max_prob
    output_bias_total = output_bias_total * max_prob
    output_total = output_total.view(b, s, h)
    output_bias_total = output_bias_total.view(b, s, h)

    return output_total, output_bias_total
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment