Commit e9eef962 authored by Raul Puri's avatar Raul Puri
Browse files

Merge branch 'master_params_sharing' into 'master'

Parameters sharing

See merge request ADLR/megatron-lm!82
parents 3ee811be c9c69c1e
......@@ -89,6 +89,14 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.min_lr <= args.lr
if args.save is not None:
assert args.save_interval is not None
# Parameters sharing does not work with torch DDP.
if (args.num_unique_layers is not None) and (args.num_layers is not None):
assert args.num_unique_layers <= args.num_layers
assert args.num_layers % args.num_unique_layers == 0, \
'num-layers should be divisible by num-unique-layers.'
if args.num_unique_layers < args.num_layers:
assert args.DDP_impl == 'local', \
'torch-DDP does not work with parameters sharing.'
_print_args(args)
return args
......@@ -116,6 +124,16 @@ def _add_network_size_args(parser):
group.add_argument('--num-layers', type=int, default=None,
help='Number of transformer layers.')
group.add_argument('--num-unique-layers', type=int, default=None,
help='Number of unique transformer layers. '
'`num-layers` should be divisible by this value.')
group.add_argument('--param-sharing-style', default='grouped',
choices=['grouped', 'spaced'],
help='Ordering of the shared parameters. For example, '
'for a `num-layers`=4 and `--num-unique-layers`=2, '
'we will have the following ordering for two unique '
'layers 1 and 2: '
' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
group.add_argument('--hidden-size', type=int, default=None,
help='Tansformer hidden size.')
group.add_argument('--num-attention-heads', type=int, default=None,
......
......@@ -360,34 +360,60 @@ class ParallelTransformer(MegatronModule):
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
def get_layer(layer_number):
# Number of layers:
self.num_layers = args.num_layers
self.num_unique_layers = args.num_unique_layers
if self.num_unique_layers is None:
self.num_unique_layers = self.num_layers
assert self.num_layers % self.num_unique_layers == 0, \
'number of layers should be divisible by number of unique layers'
self.param_sharing_style = args.param_sharing_style
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method, layer_number)
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer(i + 1) for i in range(args.num_layers)])
[build_layer(i + 1) for i in range(self.num_unique_layers)])
# Print layer ordering.
if self.num_layers != self.num_unique_layers:
if torch.distributed.get_rank() == 0:
print('> will be using the following layer ordering:')
for i in range(self.num_layers):
print(' layer id: {:3d} --> unique layer id: '
'{:3d}'.format(i, self._get_layer_index(i)),
flush=True)
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
def _get_layer_index(self, layer_number):
if self.param_sharing_style == 'grouped':
return layer_number % self.num_unique_layers
if self.param_sharing_style == 'spaced':
return layer_number // (self.num_layers // self.num_unique_layers)
assert False, 'should not be here'
def _get_layer(self, layer_number):
return self.layers[self._get_layer_index(layer_number)]
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_ = inputs[0]
for layer in layers_:
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, inputs[1])
return x_
return custom_forward
l = 0
num_layers = len(self.layers)
while l < num_layers:
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
......@@ -414,10 +440,11 @@ class ParallelTransformer(MegatronModule):
else:
if get_key_value:
presents = []
for i, layer in enumerate(self.layers):
for index in range(self.num_layers):
layer = self._get_layer(index)
past = None
if layer_past is not None:
past = layer_past[i]
past = layer_past[index]
hidden_states = layer(hidden_states,
attention_mask,
layer_past=past,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment