Commit 31d39ec0 authored by Mostofa Patwary

Merge branch 'main_retriver_merge' into main_retriver_merge_ict_eval

parents 612f438a 9dc111cc
...@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
assert args.world_size % model_parallel_size == 0, 'world size is not'\
' divisible by tensor parallel size ({}) times pipeline parallel ' \
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.data_parallel_size = args.world_size // model_parallel_size
...@@ -116,6 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers is not divisible by number of layers per virtual ' \
'pipeline stage'
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // args.pipeline_model_parallel_size) // \
args.num_layers_per_virtual_pipeline_stage
assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
'global batch size is not divisible by pipeline parallel size when ' \
'using interleaved schedule'
else:
args.virtual_pipeline_model_parallel_size = None
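# Illustrative example (values assumed, not defaults): with num_layers=24,
# pipeline_model_parallel_size=4 and --num-layers-per-virtual-pipeline-stage 3,
# each pipeline stage owns 24 // 4 = 6 layers, so
# virtual_pipeline_model_parallel_size = 6 // 3 = 2 model chunks per stage.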
# Parameters dtype.
args.params_dtype = torch.float
...@@ -202,7 +214,23 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.checkpoint_activations, \
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
# custom kernel constraints check
seq_len = args.seq_length
attn_batch_size = \
(args.num_attention_heads / args.tensor_model_parallel_size) * \
args.micro_batch_size
# constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
print('WARNING: constraints for invoking optimized'
' fused softmax kernel are not met. We default back to unfused'
' kernel invocations.')
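# Illustrative example (values assumed): with seq_length=1024,
# num_attention_heads=16, tensor_model_parallel_size=2 and micro_batch_size=4,
# attn_batch_size = (16 / 2) * 4 = 32; since 16 < 1024 <= 2048 and both 1024
# and 32 are divisible by 4, the optimized fused softmax kernel can be used.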
# Load scaled_masked_softmax_fusion_kernels
if args.masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
...@@ -478,9 +506,9 @@ def _add_checkpointing_args(parser):
help='Output directory to save checkpoints to.')
group.add_argument('--save-interval', type=int, default=None,
help='Number of iterations between checkpoint saves.')
group.add_argument('--no-save-optim', action='store_true', default=None,
help='Do not save current optimizer.')
group.add_argument('--no-save-rng', action='store_true', default=None,
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
...@@ -541,6 +569,8 @@ def _add_distributed_args(parser):
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
help='Number of layers per virtual pipeline stage')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
...@@ -548,6 +578,9 @@ def _add_distributed_args(parser):
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
'to use.')
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline')
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher.')
group.add_argument('--lazy-mpu-init', type=bool, required=False,
...
...@@ -21,12 +21,12 @@ import sys
import numpy as np
import torch
from megatron import (get_args,
mpu,
print_rank_0,
update_num_microbatches,
utils)
_CHECKPOINT_VERSION = None
...@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
args = get_args()
# Only rank zero of the data parallel writes to the disk.
model = utils.unwrap_model(model)
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
...@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
if len(model) == 1:
state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
# Optimizer stuff.
if not args.no_save_optim:
...@@ -238,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
args = get_args()
load_dir = getattr(args, load_arg)
model = utils.unwrap_model(model)
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir)
...@@ -324,7 +328,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
print_rank_0('could not find arguments in the checkpoint ...')
# Model.
if len(model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed
checkpoint_version = get_checkpoint_version()
...@@ -352,12 +361,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the rng state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
...@@ -376,8 +388,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_context_model=False,
args = get_args()
model = utils.unwrap_model(model)
load_path = args.load if from_realm_chkpt else args.ict_load
...
...@@ -133,7 +133,8 @@ def _initialize_distributed():
print('model parallel is already initialized')
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size)
def _init_autoresume():
...
...@@ -113,18 +113,23 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
data_size = input.size()
query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
attn_batch_size = data_size[0] * data_size[1]
# constraints on various tensor dimensions to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
# invoke custom kernel
if self.input_in_fp16 and mask is not None and \
custom_kernel_constraint and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
...
...@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):
def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage(ignore_virtual=True):
return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false')
...
...@@ -552,7 +552,27 @@ class ParallelTransformer(MegatronModule):
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
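# As a concrete check of the formula above (using the second example): with
# 8 layers, 2 stages and 2 virtual stages, self.num_layers = (8 // 2) // 2 = 2,
# and on pipeline rank 1 the chunk with virtual rank 1 gets
# offset = 1 * (8 // 2) + 1 * 2 = 6, i.e. layers [6, 7], matching
# "Stage 1: [2, 3] [6, 7]".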
else:
# Each stage gets a contiguous set of layers.
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
...
...@@ -38,6 +38,7 @@ from .initialize import get_pipeline_model_parallel_next_rank
from .initialize import get_pipeline_model_parallel_prev_rank
from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
...@@ -58,6 +59,8 @@ from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .utils import divide
from .utils import split_tensor_along_last_dim
...@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
...@@ -48,7 +51,8 @@ def is_unitialized():
def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1,
virtual_pipeline_model_parallel_size_=None):
""" """
Initialize model data parallel groups. Initialize model data parallel groups.
...@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size num_data_parallel_groups = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size_ is not None:
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
...@@ -258,17 +268,46 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
if get_virtual_pipeline_model_parallel_world_size() is not None and \
get_virtual_pipeline_model_parallel_rank() != 0:
return False
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
virtual_pipeline_model_parallel_world_size = \
get_virtual_pipeline_model_parallel_world_size()
if virtual_pipeline_model_parallel_world_size is not None and \
get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1):
return False
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
def set_virtual_pipeline_model_parallel_rank(rank):
"""Set the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_virtual_pipeline_model_parallel_world_size():
"""Return the virtual pipeline-parallel world size."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
...@@ -276,11 +315,13 @@ def get_tensor_model_parallel_src_rank():
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
...@@ -294,6 +335,7 @@ def get_pipeline_model_parallel_next_rank():
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
...@@ -301,6 +343,7 @@ def get_pipeline_model_parallel_prev_rank():
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
...
...@@ -23,7 +23,7 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
...@@ -32,18 +32,19 @@ def _get_params_for_weight_decay_optimization(modules):
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module in modules:
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
import torch
from megatron import get_args
from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False):
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
mpu.get_tensor_model_parallel_world_size()
else:
tensor_chunk_shape = tensor_shape
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_prev:
tensor_recv_prev = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if args.scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
if use_ring_exchange:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group())
else:
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_()
return tensor_recv_prev, tensor_recv_next
def recv_forward(timers=None, use_ring_exchange=False):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
def recv_backward(timers=None, use_ring_exchange=False):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor, timers=None, use_ring_exchange=False):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-send').stop()
def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
timers('backward-send').start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
if timers is not None:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
if timers is not None:
timers('backward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
use_ring_exchange=use_ring_exchange)
if timers is not None:
timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
use_ring_exchange=True)
if timers is not None:
timers('forward-send-forward-recv').stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
use_ring_exchange=True)
if timers is not None:
timers('backward-send-backward-recv').stop()
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
use_ring_exchange=True)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor."""
timers = get_timers()
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
return output_tensor
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
args = get_args()
timers = get_timers()
timers('backward-compute').start()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
timers('backward-compute').stop()
return input_tensor_grad
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
Returns dictionary with losses."""
assert len(model) == 1
model = model[0]
losses_reduced = []
for i in range(get_num_microbatches()):
input_tensor, output_tensor_grad = None, None
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
return losses_reduced
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (
num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches,
num_microbatches)
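# Illustrative example (values assumed, not defaults): with
# pipeline_parallel_size=4, num_model_chunks=2 and get_num_microbatches()=8
# (so num_microbatches=16), rank 0 runs (4-0-1)*2 + (2-1)*4 = 10 warmup
# forward passes while rank 3 runs (4-3-1)*2 + (2-1)*4 = 4, i.e. earlier
# stages do more warmup before entering 1F1B, as described above.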
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
return model_chunk_id
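# For example, with pipeline_parallel_size=2 and num_model_chunks=2 the
# forward microbatch ids 0,1,2,3,4,5,... map to model chunks 0,0,1,1,0,0,...
# while the backward ids map to chunks 1,1,0,0,1,1,... since the backward
# pass walks the model chunks in reverse order.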
def forward_step_helper(microbatch_id):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
return output_tensor
def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = \
backward_step(optimizer,
input_tensor,
output_tensor,
output_tensor_grad)
return input_tensor_grad
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(timers, use_ring_exchange=True))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if mpu.is_pipeline_last_stage():
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
input_tensor_grad = None
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
input_tensor, output_tensor_grad = \
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev, timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if mpu.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if mpu.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
forward=True)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
forward=False)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
input_tensor, output_tensor_grad = \
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(
output_tensor_grad)
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(timers, use_ring_exchange=True))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next, timers))
return losses_reduced
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers,
forward_only):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
timers = get_timers()
assert len(model) == 1
model = model[0]
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
# Barrier before first receive to measure forward stall.
if i == (num_warmup_microbatches - 1):
timers('forward-pipeline-stall').start()
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
timers('forward-pipeline-stall').stop()
p2p_communication.send_forward(output_tensor, timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Barrier before first receive to measure forward stall.
if num_warmup_microbatches == 0:
timers('forward-pipeline-stall').start()
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
timers('forward-pipeline-stall').stop()
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = p2p_communication.recv_forward(timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
p2p_communication.send_forward(output_tensor, timers)
else:
output_tensor_grad = \
p2p_communication.send_forward_recv_backward(output_tensor,
timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
if last_iteration:
input_tensor = None
p2p_communication.send_backward(input_tensor_grad, timers)
else:
input_tensor = \
p2p_communication.send_backward_recv_forward(
input_tensor_grad, timers)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
p2p_communication.send_backward(input_tensor_grad, timers)
return losses_reduced
...@@ -18,6 +18,7 @@
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
...@@ -26,11 +27,25 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def unwrap_model(model, module_instances=(torchDDP)):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
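# A minimal usage sketch (hypothetical, not part of this commit): unwrap_model
# accepts a single module or a list of model chunks and strips every wrapper
# layer whose class is listed in module_instances, e.g.
#   wrapped_chunks = [torchDDP(chunk) for chunk in model_chunks]
#   bare_chunks = unwrap_model(wrapped_chunks)       # -> list of plain modules
#   bare_module = unwrap_model(torchDDP(module))     # -> single plain module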
def calc_params_l2_norm(model):
"""Calculate l2 norm of parameters """
# Remove duplicate params.
...@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler):
"""Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint
args = get_args()
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistency.
...
...@@ -38,7 +38,7 @@ def model_provider():
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
def model_provider_pipelined():
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
...@@ -51,6 +51,17 @@ def model_provider():
else:
model = BertModelIntermediateStage(
num_tokentypes=num_tokentypes)
return model
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model.append(model_provider_pipelined())
else:
model = model_provider_pipelined()
else:
model = BertModel(
num_tokentypes=num_tokentypes,
...@@ -92,8 +103,8 @@ def forward_step(data_iterator, model, input_tensor):
# Get the batch.
timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
data_iterator)
timers('batch-generator').stop()
if not args.bert_binary_head:
...
...@@ -35,8 +35,8 @@ def model_provider():
"""Build the model."""
print_rank_0('building GPT model ...')
args = get_args()
def model_provider_pipelined():
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPTModelFirstStage(num_tokentypes=0)
...@@ -46,6 +46,17 @@ def model_provider():
else:
model = GPTModelIntermediateStage(
num_tokentypes=0)
return model
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model.append(model_provider_pipelined())
else:
model = model_provider_pipelined()
else:
model = GPTModel(num_tokentypes=0, parallel_output=True)
...
...@@ -92,7 +92,9 @@ def main():
"""Main program."""
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'no_load_rng': True,
'no_load_optim': True})
# Set up model and load checkpoint.
model = get_model(model_provider)
...
...@@ -200,6 +200,8 @@ def main():
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'no_save_optim': True,
'no_save_rng': True,
'save_interval': 1})
args = get_args()
...@@ -240,6 +242,11 @@ def main():
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
for rank in range(args.tensor_model_parallel_size):
# Reset these since load_checkpoint asserts they are 0, but we are loading
# multiple checkpoints in the same process and they get set each time
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
mpu.initialize.set_tensor_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
model_ = get_model(model_type)
...
...@@ -44,3 +44,12 @@ python remove_group_duplicates.py <file containing similar documents> <cleaned d
shuf <cleaned deduped data file> -o train_data.json
```
# Deduplicating ngrams
To deduplicate the downstream tasks from the training dataset, we run the following command.
```
python filter_ngrams.py <downstream task dataset> <training dataset to deduplicate> <output training dataset>
```
We use 13-grams for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters on both sides of the match. We also drop any resulting piece shorter than 200 characters, as well as any document that gets split more than 10 times.
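As a concrete illustration (the file names here are hypothetical), removing LAMBADA 13-grams from one training shard would look like:
```
python filter_ngrams.py lambada_test.json openwebtext_shard.json openwebtext_shard_deduped.json
```
The first argument supplies the downstream-task text whose 13-grams are collected, the second is the training data being cleaned, and the third is the output path.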
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deduplicate downstream tasks from the training dataset. 13-grams are used.
All split pieces shorter than 200 characters are filtered out, as is any
document that is split more than 10 times.
"""
from functools import partial
import json
import multiprocessing
import nltk
import re
import string
import sys
import time
def get_words(text):
# get all the lowercase words from text
words, positions = [], []
for match in re.finditer(r'\w+', text.lower()):
words.append(match.group(0))
positions.append(match.start())
return words, positions
def free_ngram(line, ngrams, ngram_size, filter_text_len,
splits_count, split_window_each_size):
# remove all the ngrams
try:
myjson = json.loads(line)
text_buf = [myjson['text']]
except Exception as e:
print("Error: {}".format(e), flush=True)
text_buf = []
text_buf_ngram_free = []
while len(text_buf) > 0:
# get the first one from the buffer
text = text_buf.pop(0)
words, positions = get_words(text)
not_ngram_free = True
punctuations = ".!?"
# find n-grams
for i in range(len(words) - ngram_size + 1):
seq = " ".join(words[i:i+ngram_size])
if seq in ngrams:
# splits the text
# first part of the text
pos = positions[i] - split_window_each_size
text_first = ""
while pos > 0 and not text[pos] in punctuations:
pos -= 1
if pos > 0:
text_first = text[0:pos+1]
pos = positions[i] + split_window_each_size
# last part of the text
text_second = ""
while pos < len(text) and not text[pos] in punctuations:
pos += 1
if pos + 1 < len(text):
text_second = text[pos+1:len(text)]
# first part of ngrams free
if len(text_first) > filter_text_len:
text_buf_ngram_free.append(text_first)
# add second part for further processing
if len(text_second) > filter_text_len:
text_buf.append(text_second)
not_ngram_free = False
break
# text is ngram free
if not_ngram_free:
text_buf_ngram_free.append(text)
return text_buf_ngram_free
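# Worked example of the splitting above (numbers are illustrative): if a banned
# 13-gram starts at character position 1000 and split_window_each_size is 200,
# text_first ends at the last ".!?" punctuation found at or before position
# 800 and text_second starts just after the first punctuation found at or
# after position 1200; pieces no longer than filter_text_len are dropped, and
# text_second is pushed back onto text_buf to be scanned again.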
if __name__ == '__main__':
print('finding possible duplicate content ...')
main_file = sys.argv[1] # lambada file
dedup_file = sys.argv[2] # Book corpus
output_file = sys.argv[3] # Filtered book corpus
ngrams = {}
id_prefix = "lambada"
# we use 13-grams; any split piece shorter than 200 characters is removed,
# as is any document that gets split more than 10 times
ngram_size = 13
filter_text_len = 200
splits_count = 10
split_window_each_size = 200
print('Reading file {} and computing ngrams'.format(main_file))
with open(main_file, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
words, positions = get_words(myjson['text'])
for i in range(len(words) - ngram_size+1):
seq = " ".join(words[i:i+ngram_size])
if seq not in ngrams:
ngrams[seq] = positions[i]
except Exception as e:
print('Error:', e)
print("ngrams size {}".format(len(ngrams)))
print('Reading file {} and deduping n-grams'.format(dedup_file))
counter = 0
start_time = time.time()
out_f = open(output_file, 'wb')
splitted, ignored, split_mt_thld = 0, 0, 0
# Setup multi-processing.
num_workers = 40
fin = open(dedup_file, 'r', encoding='utf-8')
pool = multiprocessing.Pool(num_workers)
free_ngram_x=partial(free_ngram, ngrams=ngrams, ngram_size=ngram_size,
filter_text_len=filter_text_len, splits_count=splits_count,
split_window_each_size=split_window_each_size)
free_ngrams = pool.imap(free_ngram_x, fin, 25)
for text_buf_ngram_free in free_ngrams:
counter += 1
try:
if len(text_buf_ngram_free) > 1:
splitted += (len(text_buf_ngram_free) - 1)
if len(text_buf_ngram_free) == 0:
ignored += 1
# more than 10 splits ignored
if len(text_buf_ngram_free) > splits_count:
text_buf_ngram_free = []
split_mt_thld += 1
for i in range(len(text_buf_ngram_free)):
split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
+ '-{:010d}'.format(int(i))
outjson = json.dumps({"text":text_buf_ngram_free[i],
id_prefix+"_split_id":split_id_string},
ensure_ascii=False)
out_f.write(outjson.encode('utf-8'))
out_f.write('\n'.encode('utf-8'))
if counter % 1000 == 0:
print(' [search]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
except Exception as e:
print('Error:', e)
print("Deduped file written to: {}".format(output_file), flush=True)
print("Total docs {} splitted {} ignored {} docs with many splits {}".\
format(counter, splitted, ignored, split_mt_thld), flush=True)
print('done :-)')
...@@ -19,7 +19,7 @@ import sys
def is_similar(jaccard_similarity):
return (jaccard_similarity >= 0.7)
if __name__ == '__main__':
...