Commit 0c077a2c authored by Neel Kant

Merge branch 'master' into realm-mlm

parents 150f2384 e9eef962
@@ -89,6 +89,14 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.min_lr <= args.lr
     if args.save is not None:
         assert args.save_interval is not None
+    # Parameter sharing does not work with torch DDP.
+    if (args.num_unique_layers is not None) and (args.num_layers is not None):
+        assert args.num_unique_layers <= args.num_layers
+        assert args.num_layers % args.num_unique_layers == 0, \
+            'num-layers should be divisible by num-unique-layers.'
+        if args.num_unique_layers < args.num_layers:
+            assert args.DDP_impl == 'local', \
+                'torch-DDP does not work with parameter sharing.'
 
     _print_args(args)
     return args
@@ -116,6 +124,16 @@ def _add_network_size_args(parser):
     group.add_argument('--num-layers', type=int, default=None,
                        help='Number of transformer layers.')
+    group.add_argument('--num-unique-layers', type=int, default=None,
+                       help='Number of unique transformer layers. '
+                       '`num-layers` should be divisible by this value.')
+    group.add_argument('--param-sharing-style', default='grouped',
+                       choices=['grouped', 'spaced'],
+                       help='Ordering of the shared parameters. For example, '
+                       'for `num-layers`=4 and `num-unique-layers`=2, '
+                       'we will have the following ordering for two unique '
+                       'layers 1 and 2: '
+                       '    grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
     group.add_argument('--hidden-size', type=int, default=None,
                        help='Transformer hidden size.')
     group.add_argument('--num-attention-heads', type=int, default=None,
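For reference, the ordering each style produces can be checked with a small standalone sketch (illustrative only, not part of this diff) that mirrors the index arithmetic the transformer uses:

# Illustrative sketch only: mirrors the layer-index arithmetic used for
# parameter sharing (the helper name is not from this diff).
def layer_index(layer_number, num_layers, num_unique_layers, style):
    if style == 'grouped':
        return layer_number % num_unique_layers
    if style == 'spaced':
        return layer_number // (num_layers // num_unique_layers)
    raise ValueError('unknown param sharing style')

# num_layers=4, num_unique_layers=2, 0-based unique-layer indices:
print([layer_index(i, 4, 2, 'grouped') for i in range(4)])  # [0, 1, 0, 1]
print([layer_index(i, 4, 2, 'spaced') for i in range(4)])   # [0, 0, 1, 1]
# i.e. unique layers (1, 2) appear as [1, 2, 1, 2] and [1, 1, 2, 2].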
@@ -80,10 +80,20 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     implementation is at the batch sampler level, instead of just the
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
-    sampler."""
+    sampler.
+
+    The `interleave` argument specifies how to distribute a batch. A value
+    of True combined with the above random sampler is equivalent to pytorch's
+    torch.utils.data.distributed.DistributedSampler.
+
+    For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2
+    specifying True will result in the following samples for each gpu:
+        GPU0: [0,2,4,6]  GPU1: [1,3,5,7]
+    specifying False will result in the following samples:
+        GPU0: [0,1,2,3]  GPU1: [4,5,6,7]"""

     def __init__(self, sampler, batch_size, drop_last, rank=-1,
-                 world_size=2, wrap_last=False):
+                 world_size=2, wrap_last=False, interleave=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
                                                       drop_last)
         if rank == -1:
@@ -95,6 +105,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
         self.wrap_around = 0
         self.wrap_last = wrap_last
         self.start_iter = 0
+        self.interleave = interleave

     def __iter__(self):
         batch = []
@@ -130,6 +141,8 @@ class DistributedBatchSampler(data.sampler.BatchSampler):

     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
+        if self.interleave:
+            return batch[self.rank:self.batch_size:self.world_size]
         start = self.rank * self.batch_size // self.world_size
         end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
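The two `_batch` slicing modes can be reproduced outside the class; the following standalone snippet reruns the docstring's example (batch of 8, data parallelism of 2):

# Standalone illustration of the two _batch slicing modes from the docstring.
batch = list(range(8))          # [0, 1, 2, 3, 4, 5, 6, 7]
batch_size, world_size = 8, 2

for rank in range(world_size):
    interleaved = batch[rank:batch_size:world_size]
    start = rank * batch_size // world_size
    end = (rank + 1) * batch_size // world_size
    contiguous = batch[start:end]
    print(rank, interleaved, contiguous)
# 0 [0, 2, 4, 6] [0, 1, 2, 3]
# 1 [1, 3, 5, 7] [4, 5, 6, 7]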
@@ -29,11 +29,15 @@ from megatron.global_vars import set_global_variables

 def initialize_megatron(extra_args_provider=None, args_defaults={},
-                        ignore_unknown_args=False):
+                        ignore_unknown_args=False, allow_no_cuda=False):
     """Set global variables, initialize distributed, and
-    set autoresume and random seeds."""
-
-    # Make sure cuda is available.
-    assert torch.cuda.is_available(), 'Megatron requires CUDA.'
+    set autoresume and random seeds.
+    `allow_no_cuda` should not be set unless using Megatron for CPU-only
+    data processing. In general this arg should not be set unless you know
+    what you are doing."""
+
+    if not allow_no_cuda:
+        # Make sure cuda is available.
+        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
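As a hypothetical usage sketch (not part of this diff; it assumes `initialize_megatron` is importable from `megatron.initialize` and that the script is launched with the usual Megatron command-line arguments), a CPU-only preprocessing entry point could opt out of the CUDA check like this:

# Hypothetical CPU-only entry point; argument defaults are placeholders and
# the import path is an assumption about this repository's layout.
from megatron.initialize import initialize_megatron

if __name__ == '__main__':
    initialize_megatron(extra_args_provider=None,
                        args_defaults={},
                        allow_no_cuda=True)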
@@ -360,34 +360,60 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_activations = args.checkpoint_activations
         self.checkpoint_num_layers = args.checkpoint_num_layers

-        def get_layer(layer_number):
+        # Number of layers:
+        self.num_layers = args.num_layers
+        self.num_unique_layers = args.num_unique_layers
+        if self.num_unique_layers is None:
+            self.num_unique_layers = self.num_layers
+        assert self.num_layers % self.num_unique_layers == 0, \
+            'number of layers should be divisible by number of unique layers'
+        self.param_sharing_style = args.param_sharing_style
+
+        # Transformer layers.
+        def build_layer(layer_number):
             return ParallelTransformerLayer(
                 attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method, layer_number)
-
-        # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i + 1) for i in range(args.num_layers)])
+            [build_layer(i + 1) for i in range(self.num_unique_layers)])
+
+        # Print layer ordering.
+        if self.num_layers != self.num_unique_layers:
+            if torch.distributed.get_rank() == 0:
+                print('> will be using the following layer ordering:')
+                for i in range(self.num_layers):
+                    print('   layer id: {:3d} --> unique layer id: '
+                          '{:3d}'.format(i, self._get_layer_index(i)),
+                          flush=True)

         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

+    def _get_layer_index(self, layer_number):
+        if self.param_sharing_style == 'grouped':
+            return layer_number % self.num_unique_layers
+        if self.param_sharing_style == 'spaced':
+            return layer_number // (self.num_layers // self.num_unique_layers)
+        assert False, 'should not be here'
+
+    def _get_layer(self, layer_number):
+        return self.layers[self._get_layer_index(layer_number)]
+
     def _checkpointed_forward(self, hidden_states, attention_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*inputs):
-                layers_ = self.layers[start:end]
                 x_ = inputs[0]
-                for layer in layers_:
+                for index in range(start, end):
+                    layer = self._get_layer(index)
                     x_ = layer(x_, inputs[1])
                 return x_
             return custom_forward

         l = 0
-        num_layers = len(self.layers)
-        while l < num_layers:
+        while l < self.num_layers:
             hidden_states = mpu.checkpoint(
                 custom(l, l + self.checkpoint_num_layers),
                 hidden_states, attention_mask)
@@ -414,10 +440,11 @@ class ParallelTransformer(MegatronModule):
         else:
             if get_key_value:
                 presents = []
-            for i, layer in enumerate(self.layers):
+            for index in range(self.num_layers):
+                layer = self._get_layer(index)
                 past = None
                 if layer_past is not None:
-                    past = layer_past[i]
+                    past = layer_past[index]
                 hidden_states = layer(hidden_states,
                                       attention_mask,
                                       layer_past=past,
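To illustrate what the sharing buys, the following self-contained sketch (a toy stand-in for ParallelTransformerLayer, not Megatron code) allocates only num_unique_layers parameter sets and revisits them across num_layers forward steps, using the same index arithmetic as _get_layer_index:

import torch

class ToyLayer(torch.nn.Module):
    """Toy stand-in for ParallelTransformerLayer; only the sharing pattern matters."""
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return torch.relu(self.linear(x))


class SharedStack(torch.nn.Module):
    def __init__(self, hidden_size, num_layers, num_unique_layers,
                 style='grouped'):
        super().__init__()
        assert num_layers % num_unique_layers == 0
        self.num_layers = num_layers
        self.num_unique_layers = num_unique_layers
        self.style = style
        # Only num_unique_layers sets of weights are allocated.
        self.layers = torch.nn.ModuleList(
            [ToyLayer(hidden_size) for _ in range(num_unique_layers)])

    def _index(self, i):
        # Same arithmetic as _get_layer_index above.
        if self.style == 'grouped':
            return i % self.num_unique_layers
        return i // (self.num_layers // self.num_unique_layers)

    def forward(self, x):
        # num_layers logical layers, but the shared modules are revisited.
        for i in range(self.num_layers):
            x = self.layers[self._index(i)](x)
        return x


stack = SharedStack(hidden_size=8, num_layers=4, num_unique_layers=2)
print(sum(p.numel() for p in stack.parameters()))  # 144 = 2 * (8*8 + 8)
print(stack(torch.randn(3, 8)).shape)              # torch.Size([3, 8])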