"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "d520d24fdf2eb7c2e76aa5ca3020cbfd07c42910"
Commit 80f90dcd authored by Mohammad

added parameter sharing

parent ad299490
@@ -116,6 +116,16 @@ def _add_network_size_args(parser):
     group.add_argument('--num-layers', type=int, default=None,
                        help='Number of transformer layers.')
+    group.add_argument('--num-unique-layers', type=int, default=None,
+                       help='Number of unique transformer layers. '
+                       '`num-layers` should be divisible by this value.')
+    group.add_argument('--param-sharing-style', default='grouped',
+                       choices=['grouped', 'spaced'],
+                       help='Ordering of the shared parameters. For example, '
+                       'for a `num-layers`=4 and `--num-unique-layers`=2, '
+                       'we will have the following ordering for two unique '
+                       'layers 1 and 2: '
+                       '    grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Transformer hidden size.')
     group.add_argument('--num-attention-heads', type=int, default=None,
...
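As an aside (not part of the commit), the two orderings described in the `--param-sharing-style` help text can be reproduced with a short standalone sketch; `layer_ordering` below is a hypothetical helper and uses 1-based labels to match the help string:

```python
# Hypothetical helper, for illustration only: reproduces the orderings
# described in the --param-sharing-style help text (1-based labels).
def layer_ordering(num_layers, num_unique_layers, style):
    if style == 'grouped':
        # Unique layers cycle through the stack: [1, 2, 1, 2] for 4 layers / 2 unique.
        return [(i % num_unique_layers) + 1 for i in range(num_layers)]
    if style == 'spaced':
        # Each unique layer repeats consecutively: [1, 1, 2, 2].
        return [i // (num_layers // num_unique_layers) + 1
                for i in range(num_layers)]
    raise ValueError('unknown param sharing style: {}'.format(style))


assert layer_ordering(4, 2, 'grouped') == [1, 2, 1, 2]
assert layer_ordering(4, 2, 'spaced') == [1, 1, 2, 2]
```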
@@ -360,34 +360,61 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_activations = args.checkpoint_activations
         self.checkpoint_num_layers = args.checkpoint_num_layers
 
-        def get_layer(layer_number):
+        # Number of layers:
+        self.num_layers = args.num_layers
+        self.num_unique_layers = args.num_unique_layers
+        if self.num_unique_layers is None:
+            self.num_unique_layers = self.num_layers
+        assert self.num_layers % self.num_unique_layers == 0, \
+            'number of layers should be divisible by number of unique layers'
+        self.param_sharing_style = args.param_sharing_style
+        assert self.param_sharing_style in ['grouped', 'spaced']
+
+        # Transformer layers.
+        def build_layer(layer_number):
             return ParallelTransformerLayer(
                 attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method, layer_number)
 
-        # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i + 1) for i in range(args.num_layers)])
+            [build_layer(i + 1) for i in range(self.num_unique_layers)])
+
+        # Print layer ordering.
+        if self.num_layers != self.num_unique_layers:
+            if torch.distributed.get_rank() == 0:
+                print('> will be using the following layer ordering:')
+                for i in range(self.num_layers):
+                    print('   layer: {:3d} --> unique layer: {:3d}'.format(
+                        i, self._get_layer_index(i)), flush=True)
 
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)
 
+    def _get_layer_index(self, layer_number):
+        if self.param_sharing_style == 'grouped':
+            return layer_number % self.num_unique_layers
+        if self.param_sharing_style == 'spaced':
+            return layer_number // (self.num_layers // self.num_unique_layers)
+        assert False, 'should not be here'
+
+    def _get_layer(self, layer_number):
+        return self.layers[self._get_layer_index(layer_number)]
+
     def _checkpointed_forward(self, hidden_states, attention_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*inputs):
-                layers_ = self.layers[start:end]
                 x_ = inputs[0]
-                for layer in layers_:
+                for index in range(start, end):
+                    layer = self._get_layer(index)
                     x_ = layer(x_, inputs[1])
                 return x_
             return custom_forward
 
         l = 0
-        num_layers = len(self.layers)
-        while l < num_layers:
+        while l < self.num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
                 hidden_states, attention_mask)
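The checkpointed path above now walks logical layer indices and re-resolves each one through `self._get_layer`, so shared modules are reused inside every recomputed chunk. A generic sketch of the same chunking pattern, with `torch.utils.checkpoint.checkpoint` standing in for Megatron's `mpu.checkpoint` (function and argument names here are illustrative, not from the commit):

```python
import torch
from torch.utils.checkpoint import checkpoint


def chunked_forward(get_layer, num_layers, chunk_size, hidden_states, attention_mask):
    """Run num_layers logical layers, checkpointing chunk_size of them at a time."""
    def make_chunk(start, end):
        # Each chunk re-resolves its layers by logical index, so weight
        # sharing is preserved inside the recomputed segment.
        def run_chunk(x, mask):
            for index in range(start, end):
                x = get_layer(index)(x, mask)
            return x
        return run_chunk

    start = 0
    while start < num_layers:
        end = min(start + chunk_size, num_layers)
        # Only the chunk boundaries keep activations; the interior is
        # recomputed during the backward pass.
        hidden_states = checkpoint(make_chunk(start, end), hidden_states, attention_mask)
        start = end
    return hidden_states
```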
@@ -414,10 +441,11 @@ class ParallelTransformer(MegatronModule):
         else:
             if get_key_value:
                 presents = []
-            for i, layer in enumerate(self.layers):
+            for index in range(self.num_layers):
+                layer = self._get_layer(index)
                 past = None
                 if layer_past is not None:
-                    past = layer_past[i]
+                    past = layer_past[index]
                 hidden_states = layer(hidden_states,
                                       attention_mask,
                                       layer_past=past,
...
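Because several logical layer positions can resolve to the same module object, their weights are shared outright, while per-position state such as `layer_past[index]` in the loop above remains separate. A rough sanity check of that distinction, using `torch.nn.Linear` as a stand-in for `ParallelTransformerLayer` (illustrative only):

```python
import torch

num_layers, num_unique_layers = 4, 2

# Stand-in modules: one per *unique* layer, as in the commit.
layers = torch.nn.ModuleList(
    [torch.nn.Linear(8, 8) for _ in range(num_unique_layers)])


def get_layer(index):
    # 'grouped' mapping: logical index -> unique index.
    return layers[index % num_unique_layers]


# Logical layers 0 and 2 are literally the same module, so their weights
# are shared, not merely equal in value.
assert get_layer(0) is get_layer(2)

# Parameter count scales with the number of unique layers only.
assert sum(p.numel() for p in layers.parameters()) == num_unique_layers * (8 * 8 + 8)

# Per-position state (e.g. a key/value cache) is still indexed by the
# logical layer number, so each of the 4 positions keeps its own entry.
layer_past = [None] * num_layers
```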