Commit 0c077a2c authored by Neel Kant

Merge branch 'master' into realm-mlm

parents 150f2384 e9eef962
@@ -89,6 +89,14 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.min_lr <= args.lr
     if args.save is not None:
         assert args.save_interval is not None
+    # Parameters sharing does not work with torch DDP.
+    if (args.num_unique_layers is not None) and (args.num_layers is not None):
+        assert args.num_unique_layers <= args.num_layers
+        assert args.num_layers % args.num_unique_layers == 0, \
+            'num-layers should be divisible by num-unique-layers.'
+        if args.num_unique_layers < args.num_layers:
+            assert args.DDP_impl == 'local', \
+                'torch-DDP does not work with parameters sharing.'
     _print_args(args)
     return args
@@ -116,6 +124,16 @@ def _add_network_size_args(parser):
     group.add_argument('--num-layers', type=int, default=None,
                        help='Number of transformer layers.')
+    group.add_argument('--num-unique-layers', type=int, default=None,
+                       help='Number of unique transformer layers. '
+                       '`num-layers` should be divisible by this value.')
+    group.add_argument('--param-sharing-style', default='grouped',
+                       choices=['grouped', 'spaced'],
+                       help='Ordering of the shared parameters. For example, '
+                       'for a `num-layers`=4 and `--num-unique-layers`=2, '
+                       'we will have the following ordering for two unique '
+                       'layers 1 and 2: '
+                       '    grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Tansformer hidden size.')
     group.add_argument('--num-attention-heads', type=int, default=None,
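As a side note on how the two new flags fit together, here is a minimal standalone argparse sketch (not part of the commit; the flag names simply mirror the Megatron ones) that reproduces the divisibility check added to parse_args above:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--num-layers', type=int, default=None)
    parser.add_argument('--num-unique-layers', type=int, default=None)
    parser.add_argument('--param-sharing-style', default='grouped',
                        choices=['grouped', 'spaced'])
    args = parser.parse_args(['--num-layers', '4',
                              '--num-unique-layers', '2'])

    # Same constraint as the new parse_args check: each unique layer must be
    # reused a whole number of times.
    if (args.num_unique_layers is not None) and (args.num_layers is not None):
        assert args.num_unique_layers <= args.num_layers
        assert args.num_layers % args.num_unique_layers == 0, \
            'num-layers should be divisible by num-unique-layers.'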
@@ -80,10 +80,20 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
-    sampler."""
+    sampler.
+
+    The `interleave` argument specifies how to distribute a batch. A value
+    of True combined with the above random sampler is equivalent to pytorch's
+    torch.utils.data.distributed.DistributedSampler.
+
+    For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2
+    specifying True will result in the following samples for each gpu:
+        GPU0: [0,2,4,6]  GPU1: [1,3,5,7]
+    specifying False will result in the following samples:
+        GPU0: [0,1,2,3]  GPU1: [4,5,6,7]"""

     def __init__(self, sampler, batch_size, drop_last, rank=-1,
-                 world_size=2, wrap_last=False):
+                 world_size=2, wrap_last=False, interleave=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
                                                       drop_last)
         if rank == -1:
@@ -95,6 +105,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
         self.wrap_around = 0
         self.wrap_last = wrap_last
         self.start_iter = 0
+        self.interleave = interleave

     def __iter__(self):
         batch = []
@@ -130,6 +141,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
+        if self.interleave:
+            return batch[self.rank:self.batch_size:self.world_size]
         start = self.rank * self.batch_size // self.world_size
         end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
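A quick standalone sketch (assumed example values, not part of the commit) of the two slicing strategies in _batch, reproducing the GPU0/GPU1 split described in the class docstring:

    batch = list(range(8))          # [0, 1, 2, 3, 4, 5, 6, 7]
    batch_size, world_size = 8, 2

    for rank in range(world_size):
        # interleave=True: strided slice, this rank takes every world_size-th sample.
        interleaved = batch[rank:batch_size:world_size]
        # interleave=False: contiguous chunk of the batch for this rank.
        start = rank * batch_size // world_size
        end = (rank + 1) * batch_size // world_size
        contiguous = batch[start:end]
        print(rank, interleaved, contiguous)
    # rank 0: [0, 2, 4, 6] [0, 1, 2, 3]
    # rank 1: [1, 3, 5, 7] [4, 5, 6, 7]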
@@ -29,11 +29,15 @@ from megatron.global_vars import set_global_variables
 def initialize_megatron(extra_args_provider=None, args_defaults={},
-                        ignore_unknown_args=False):
+                        ignore_unknown_args=False, allow_no_cuda=False):
     """Set global variables, initialize distributed, and
-    set autoresume and random seeds."""
+    set autoresume and random seeds.
+    `allow_no_cuda` should not be set unless using megatron for cpu only
+    data processing. In general this arg should not be set unless you know
+    what you are doing."""

-    # Make sure cuda is available.
-    assert torch.cuda.is_available(), 'Megatron requires CUDA.'
+    if not allow_no_cuda:
+        # Make sure cuda is available.
+        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
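For callers, the new argument would be used roughly like the sketch below (the import path is an assumption based on the surrounding module, not shown in the diff):

    # Assumed import path for the function defined above.
    from megatron.initialize import initialize_megatron

    # CPU-only data processing: explicitly opt out of the CUDA assertion.
    # All other code paths still expect CUDA to be available.
    initialize_megatron(extra_args_provider=None, args_defaults={},
                        allow_no_cuda=True)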
@@ -360,34 +360,60 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_activations = args.checkpoint_activations
         self.checkpoint_num_layers = args.checkpoint_num_layers

-        def get_layer(layer_number):
+        # Number of layers:
+        self.num_layers = args.num_layers
+        self.num_unique_layers = args.num_unique_layers
+        if self.num_unique_layers is None:
+            self.num_unique_layers = self.num_layers
+        assert self.num_layers % self.num_unique_layers == 0, \
+            'number of layers should be divisible by number of unique layers'
+        self.param_sharing_style = args.param_sharing_style
+
+        # Transformer layers.
+        def build_layer(layer_number):
             return ParallelTransformerLayer(
                 attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method, layer_number)

-        # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i + 1) for i in range(args.num_layers)])
+            [build_layer(i + 1) for i in range(self.num_unique_layers)])
+
+        # Print layer ordering.
+        if self.num_layers != self.num_unique_layers:
+            if torch.distributed.get_rank() == 0:
+                print('> will be using the following layer ordering:')
+                for i in range(self.num_layers):
+                    print('   layer id: {:3d} --> unique layer id: '
+                          '{:3d}'.format(i, self._get_layer_index(i)),
+                          flush=True)

         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

+    def _get_layer_index(self, layer_number):
+        if self.param_sharing_style == 'grouped':
+            return layer_number % self.num_unique_layers
+        if self.param_sharing_style == 'spaced':
+            return layer_number // (self.num_layers // self.num_unique_layers)
+        assert False, 'should not be here'
+
+    def _get_layer(self, layer_number):
+        return self.layers[self._get_layer_index(layer_number)]
+
     def _checkpointed_forward(self, hidden_states, attention_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*inputs):
-                layers_ = self.layers[start:end]
                 x_ = inputs[0]
-                for layer in layers_:
+                for index in range(start, end):
+                    layer = self._get_layer(index)
                     x_ = layer(x_, inputs[1])
                 return x_
             return custom_forward

         l = 0
-        num_layers = len(self.layers)
-        while l < num_layers:
+        while l < self.num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
                 hidden_states, attention_mask)
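To make the two sharing styles concrete, here is a small standalone sketch of the mapping computed by _get_layer_index for num_layers=4 and num_unique_layers=2 (0-based indices, matching the layer-ordering printout above):

    num_layers, num_unique_layers = 4, 2

    grouped = [i % num_unique_layers for i in range(num_layers)]
    spaced = [i // (num_layers // num_unique_layers) for i in range(num_layers)]

    print(grouped)  # [0, 1, 0, 1] -> cycle through the unique layers
    print(spaced)   # [0, 0, 1, 1] -> repeat each unique layer in a contiguous run

Because _get_layer returns the same module instance for every layer id that maps to the same unique index, those layers share parameters, which is why parse_args restricts parameter sharing to the local DDP implementation.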
@@ -414,10 +440,11 @@ class ParallelTransformer(MegatronModule):
         else:
             if get_key_value:
                 presents = []
-            for i, layer in enumerate(self.layers):
+            for index in range(self.num_layers):
+                layer = self._get_layer(index)
                 past = None
                 if layer_past is not None:
-                    past = layer_past[i]
+                    past = layer_past[index]
                 hidden_states = layer(hidden_states,
                                       attention_mask,
                                       layer_past=past,