Commit 53f3efc4 authored by Lawrence McAfee

Merge branch 'switch' into 'main'

Switch Transformer

See merge request ADLR/megatron-lm!373
parents 9c5a830f 458d7785
@@ -365,7 +365,8 @@ def _add_network_size_args(parser):
     group.add_argument('--bert-no-binary-head', action='store_false',
                        help='Disable BERT binary head.',
                        dest='bert_binary_head')
+    group.add_argument('--num-experts', type=int, default=None,
+                       help='Number of Experts in Switch Transformer (None means no Switch)')
     return parser
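A minimal sketch of how the new flag behaves (illustrative only, not part of this change; only the argument definition is taken from the diff above). Leaving --num-experts unset keeps the default of None, which later selects the dense ParallelMLP path; passing an integer enables SwitchMLP with that many experts.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num-experts', type=int, default=None,
                    help='Number of Experts in Switch Transformer (None means no Switch)')

args = parser.parse_args([])                      # flag omitted -> num_experts is None (dense MLP)
print(args.num_experts)                           # None
args = parser.parse_args(['--num-experts', '4'])  # Switch Transformer with 4 experts
print(args.num_experts)                           # 4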
@@ -116,6 +116,53 @@ class ParallelMLP(MegatronModule):
         output, output_bias = self.dense_4h_to_h(intermediate_parallel)
         return output, output_bias
+
+
+class SwitchMLP(MegatronModule):
+    """
+    Routes each input token to one of N MLP "experts" (top-1 routing).
+    """
+    def __init__(self, init_method, output_layer_init_method):
+        super(SwitchMLP, self).__init__()
+        args = get_args()
+        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
+        self.experts = torch.nn.ModuleList()
+        for i in range(args.num_experts):
+            self.experts.append(ParallelMLP(init_method, output_layer_init_method))
+
+    def forward(self, hidden_states):
+        # hidden_states: [b, s, h]
+        b = hidden_states.size(0)
+        s = hidden_states.size(1)
+        h = hidden_states.size(2)
+        route = self.router(hidden_states)
+        route = torch.nn.functional.softmax(route, dim=2)
+        max_prob, max_ind = torch.max(route, dim=2)
+        max_prob = torch.unsqueeze(max_prob, 2)  # [b, s, 1]
+
+        # TODO (rprenger): this could be made easier to read.
+        # Flatten [b, s, h] to [b*s, h] so each token vector can be routed independently.
+        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s, h]
+        max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s, 1]
+        max_ind = max_ind.view(-1)  # [b*s]
+
+        output_total = torch.empty_like(hidden_states)
+        output_bias_total = torch.empty_like(hidden_states)
+        # TODO (rprenger): this runs the experts serially, but it could be parallelized.
+        for expert_num, expert in enumerate(self.experts):
+            local_indices = (max_ind == expert_num).nonzero()
+            hidden = hidden_states[local_indices, :]
+            output, output_bias = expert(hidden)
+            output_bias = output_bias.expand_as(output)
+            output_total[local_indices, :] = output
+            output_bias_total[local_indices, :] = output_bias
+
+        # Scale each token's expert output (and bias) by its routing probability.
+        output_total = output_total * max_prob
+        output_bias_total = output_bias_total * max_prob
+        output_total = output_total.view(b, s, h)
+        output_bias_total = output_bias_total.view(b, s, h)
+
+        return output_total, output_bias_total
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
@@ -479,8 +526,10 @@ class ParallelTransformerLayer(MegatronModule):
             no_persist_layer_norm=args.no_persist_layer_norm)
 
         # MLP
-        self.mlp = ParallelMLP(init_method,
-                               output_layer_init_method)
+        if args.num_experts is not None:
+            self.mlp = SwitchMLP(init_method, output_layer_init_method)
+        else:
+            self.mlp = ParallelMLP(init_method, output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
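The TODO in SwitchMLP.forward notes that the per-expert loop runs serially. One possible way to batch it, sketched below under the simplifying assumption that every expert is a single weight matrix of identical shape (rather than a full ParallelMLP), is to gather each token's expert weights and apply them with one batched matmul; all names and sizes here are illustrative only.

import torch

num_experts, num_tokens, h = 4, 16, 8
flat = torch.randn(num_tokens, h)                    # [b*s, h] flattened token vectors
max_ind = torch.randint(num_experts, (num_tokens,))  # top-1 expert id per token

# Stack per-expert weights into one tensor: [num_experts, h, h].
W = torch.randn(num_experts, h, h)

# Gather each token's expert weight and apply all experts in a single bmm,
# replacing the Python loop over experts (at the cost of extra memory for
# the gathered weights).
W_per_token = W[max_ind]                                    # [num_tokens, h, h]
out = torch.bmm(flat.unsqueeze(1), W_per_token).squeeze(1)  # [num_tokens, h]
print(out.shape)  # torch.Size([16, 8])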