"vscode:/vscode.git/clone" did not exist on "e0f4c0ecf7c08f76081053fd0cdc5b0c3e893cd8"
Commit feea48cd authored by rprenger's avatar rprenger
Browse files

Merging with main

parents 8694c7b0 0be40526
@@ -256,7 +256,7 @@ class ColumnParallelLinear(torch.nn.Module):
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=stride)
        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
@@ -286,7 +286,7 @@ class ColumnParallelLinear(torch.nn.Module):
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
@@ -316,8 +316,8 @@ class RowParallelLinear(torch.nn.Module):
    keep_master_weight_for_test: This was added for testing and should be
                                 set to False. It returns the master weights
                                 used for initialization.
-    skip_bias_add: This was added to enable performance optimations where bias
-                   can be fused with other elementwise operations. we skip
+    skip_bias_add: This was added to enable performance optimization where bias
+                   can be fused with other elementwise operations. We skip
                    adding bias but instead return it.
    """
...
@@ -20,7 +20,7 @@ from .utils import split_tensor_along_last_dim
def _reduce(input_):
-    """All-reduce the the input tensor across model parallel group."""
+    """All-reduce the input tensor across model parallel group."""
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size()==1:
...
@@ -47,9 +47,18 @@ def init_checkpointed_activations_memory_buffer():
    per_layer = args.micro_batch_size * args.max_position_embeddings * \
                args.hidden_size // args.tensor_model_parallel_size
-    assert args.num_layers % args.checkpoint_num_layers == 0, \
-        'number of layers is not divisible by checkpoint-num-layers'
-    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+    num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
+    if args.virtual_pipeline_model_parallel_size is not None:
+        num_layers = num_layers // args.virtual_pipeline_model_parallel_size
+    if args.activations_checkpoint_method == 'uniform':
+        assert num_layers % args.activations_checkpoint_num_layers == 0, \
+            'total number of layers is not divisible by checkpoint-chunk_size'
+        num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers
+    elif args.activations_checkpoint_method == 'block':
+        assert args.activations_checkpoint_num_layers <= num_layers, \
+            'total number of layers is fewer than the number of layers to checkpoint'
+        num_checkpointer_layers = args.activations_checkpoint_num_layers
    numel = per_layer * num_checkpointer_layers
    dtype = torch.half
    if not args.fp16:
...
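For a concrete sense of what the new sizing logic computes, here is a minimal standalone sketch of the same arithmetic. This is illustrative only: the names mirror the arguments in the hunk above, not the real argparse namespace, and the values are made up.

```python
# Rough re-statement of the checkpointed-activation buffer sizing above (illustrative only).
def checkpoint_buffer_numel(micro_batch_size, max_position_embeddings, hidden_size,
                            tensor_mp_size, pipeline_mp_size, num_layers,
                            method, checkpoint_num_layers,
                            virtual_pipeline_mp_size=None):
    per_layer = micro_batch_size * max_position_embeddings * hidden_size // tensor_mp_size
    layers_per_stage = num_layers // pipeline_mp_size
    if virtual_pipeline_mp_size is not None:
        layers_per_stage //= virtual_pipeline_mp_size
    if method == 'uniform':
        assert layers_per_stage % checkpoint_num_layers == 0
        num_checkpointer_layers = num_layers // checkpoint_num_layers
    elif method == 'block':
        assert checkpoint_num_layers <= layers_per_stage
        num_checkpointer_layers = checkpoint_num_layers
    return per_layer * num_checkpointer_layers

# e.g. a 24-layer model, TP=2, PP=4, checkpointing every layer uniformly:
print(checkpoint_buffer_numel(4, 2048, 2048, 2, 4, 24, 'uniform', 1))
```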
@@ -100,10 +100,12 @@ def get_megatron_optimizer(model):
                                                   args.clip_grad,
                                                   args.log_num_zeros_in_grad,
                                                   params_have_main_grad,
+                                                  args.use_contiguous_buffers_in_local_ddp,
                                                   args.bf16,
                                                   grad_scaler)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad,
                         args.log_num_zeros_in_grad,
-                         params_have_main_grad)
+                         params_have_main_grad,
+                         args.use_contiguous_buffers_in_local_ddp)
@@ -68,7 +68,9 @@ class MegatronOptimizer(ABC):
    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
-                 params_have_main_grad):
+                 params_have_main_grad,
+                 use_contiguous_buffers_in_local_ddp):
        """Input optimizer is the base optimizer for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
@@ -76,7 +78,11 @@ class MegatronOptimizer(ABC):
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
+        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
+
+        if self.use_contiguous_buffers_in_local_ddp:
+            assert self.params_have_main_grad, \
+                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        params = []
@@ -187,11 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, bf16, grad_scaler):
+                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
+                 bf16, grad_scaler):

        super(Float16OptimizerWithFloat16Params, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

        self.bf16 = bf16
        self.grad_scaler = grad_scaler
@@ -282,9 +289,14 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_from_fp32_groups."""
+        float16_groups & fp32_from_fp32_groups. We additionally zero
+        fp32_from_float16_groups as a memory optimization to reduce
+        fragmentation; in the case of set_to_none==True, the space
+        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
+        for group in self.fp32_from_float16_groups:
+            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)
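The memory point made in the new docstring is that setting gradients to `None` (rather than zeroing them in place) lets the allocator reclaim the space. A minimal stand-in for what a zero-grad helper with `set_to_none` typically does, not the actual `_zero_grad_group_helper`:

```python
import torch

def zero_grad_group(params, set_to_none=True):
    # Stand-in for a helper like _zero_grad_group_helper (illustrative only).
    for p in params:
        if p.grad is None:
            continue
        if set_to_none:
            p.grad = None          # lets the allocator reclaim the memory
        else:
            p.grad.detach_()
            p.grad.zero_()

w = torch.nn.Parameter(torch.randn(4, 4))
w.sum().backward()
zero_grad_group([w], set_to_none=True)
print(w.grad)                      # None
```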
@@ -305,12 +317,26 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()

+                # Safe to deallocate model's grad/main_grad after copying.
+                # (If using contiguous buffers, main_grad's memory should
+                # persist and therefore should not be deallocated.)
+                model_param.grad = None
+                if self.params_have_main_grad and \
+                   not self.use_contiguous_buffers_in_local_ddp:
+                    model_param.main_grad = None

        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

+                    # Safe to de-reference model's main_grad after copying.
+                    # (If using contiguous buffers, main_grad's memory should
+                    # persist and therefore should not be deallocated.)
+                    if not self.use_contiguous_buffers_in_local_ddp:
+                        model_param.main_grad = None

    def _unscale_main_grads_and_check_for_nan(self):
        main_grads = []
@@ -464,11 +490,12 @@ class FP32Optimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
-                 params_have_main_grad):
+                 params_have_main_grad,
+                 use_contiguous_buffers_in_local_ddp):

        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

        self._scale = torch.cuda.FloatTensor([1.0])
@@ -495,6 +522,12 @@ class FP32Optimizer(MegatronOptimizer):
            for param in param_group['params']:
                param.grad = param.main_grad

+                # Safe to de-reference model's main_grad after copying.
+                # (If using contiguous buffers, main_grad's memory should
+                # persist and therefore should not be deallocated.)
+                if not self.use_contiguous_buffers_in_local_ddp:
+                    param.main_grad = None

        # Clip gradients.
        grad_norm = None
        if self.clip_grad > 0.0:
...
@@ -22,7 +22,9 @@ from megatron import mpu

def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+                 use_ring_exchange=False, tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.
@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.
+        tensor_shape: optional, use when the input sequence contains fewer
+                      tokens than the default sequence length
+        override_scatter_gather_tensors_in_pipeline: optional, used when
+                                                     tensor_shape is provided
+                                                     to override the scatter
+                                                     gather tensors
+        dtype_: optional, dtype of the received tensors, used when
+                tensor_shape is provided

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    if args.scatter_gather_tensors_in_pipeline:
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
+
+    requires_grad = True
+    if dtype_ is not None:
+        dtype = dtype_
+        requires_grad = False
+
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    return tensor_recv_prev, tensor_recv_next


-def recv_forward(timers=None):
+def recv_forward(tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None, timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
@@ -135,7 +157,11 @@ def recv_forward(timers=None):
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor
@@ -158,8 +184,11 @@ def recv_backward(timers=None):
    return output_tensor_grad


-def send_forward(output_tensor, timers=None):
+def send_forward(output_tensor, timers=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
-            recv_next=False)
+            recv_next=False,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
        if timers is not None:
            timers('forward-send').stop()
...
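As a side note on the scatter-gather path that these changes make optional: splitting a tensor into tensor-parallel-sized 1-D chunks and gathering it back is, conceptually, a flatten/chunk/concatenate/reshape round trip. A minimal PyTorch sketch of that idea follows; this is not the `mpu` implementation (which also distributes the chunks across ranks), and the sizes are illustrative.

```python
import operator
from functools import reduce

import torch

tp_world_size = 4
tensor_shape = (512, 4, 1024)            # (seq_length, micro_batch_size, hidden_size)
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // tp_world_size

x = torch.randn(tensor_shape)
# "Scatter": flatten and split into equal 1-D chunks, one per tensor-parallel rank.
chunks = list(torch.chunk(x.reshape(-1), tp_world_size))
assert all(c.numel() == tensor_chunk_shape for c in chunks)

# "Gather": concatenate the chunks and recover the original layout.
y = torch.cat(chunks).view(tensor_shape)
assert torch.equal(x, y)
```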
@@ -31,6 +31,9 @@ def get_forward_backward_func():
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
+            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
+                'number of microbatches is not divisible by pipeline-parallel ' \
+                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
@@ -191,6 +194,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

+        # forward step
        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
@@ -202,6 +206,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

+        # if forward-only, no need to save tensors for a backward pass
+        if forward_only:
+            input_tensors[model_chunk_id].pop()
+            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
@@ -228,7 +237,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
-        p2p_communication.recv_forward(timers))
+        p2p_communication.recv_forward(timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)
@@ -262,7 +271,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
-                    output_tensor, recv_prev, timers)
+                    output_tensor, recv_prev=recv_prev, timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
@@ -340,7 +349,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers))
+                p2p_communication.recv_backward(timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -352,7 +361,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
-                    input_tensor_grad, recv_next, timers))
+                    input_tensor_grad, recv_next=recv_next, timers=timers))

    return losses_reduced
@@ -380,25 +389,30 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

-    input_tensors = []
-    output_tensors = []
+    # Input, output tensors only need to be saved when doing backward passes
+    input_tensors = None
+    output_tensors = None
+    if not forward_only:
+        input_tensors = []
+        output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
-        input_tensor = p2p_communication.recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
-        p2p_communication.send_forward(output_tensor, timers)
+        p2p_communication.send_forward(output_tensor, timers=timers)

-        input_tensors.append(input_tensor)
-        output_tensors.append(output_tensor)
+        if not forward_only:
+            input_tensors.append(input_tensor)
+            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
-        input_tensor = p2p_communication.recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
@@ -407,22 +421,24 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
-            p2p_communication.send_forward(output_tensor, timers)
+            p2p_communication.send_forward(output_tensor, timers=timers)
+
+            if not last_iteration:
+                input_tensor = p2p_communication.recv_forward(timers=timers)
+
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
-                                                             timers)
+                                                             timers=timers)

-        # Add input_tensor and output_tensor to end of list, then pop from the
-        # start of the list for backward pass.
-        input_tensors.append(input_tensor)
-        output_tensors.append(output_tensor)
-
-        if forward_only:
-            if not last_iteration:
-                input_tensor = p2p_communication.recv_forward(timers)
-        else:
-            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+            # Add input_tensor and output_tensor to end of list.
+            input_tensors.append(input_tensor)
+            output_tensors.append(output_tensor)
+
+            # Pop input_tensor and output_tensor from the start of the list for
+            # the backward pass.
+            input_tensor = input_tensors.pop(0)
+            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
@@ -430,11 +446,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
            if last_iteration:
                input_tensor = None
-                p2p_communication.send_backward(input_tensor_grad, timers)
+                p2p_communication.send_backward(input_tensor_grad, timers=timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
-                        input_tensor_grad, timers)
+                        input_tensor_grad, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
@@ -442,12 +458,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

-            output_tensor_grad = p2p_communication.recv_backward(timers)
+            output_tensor_grad = p2p_communication.recv_backward(timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

-            p2p_communication.send_backward(input_tensor_grad, timers)
+            p2p_communication.send_backward(input_tensor_grad, timers=timers)

    return losses_reduced
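The bookkeeping change above (save activations only when a backward pass will follow, append at the end of the list, pop from the front for the matching backward) can be illustrated with a tiny stand-in simulation. Integers take the place of tensors and all communication is omitted; this is only meant to show the FIFO pairing across warmup, steady-state 1F1B, and cooldown.

```python
# Toy illustration of the 1F1B activation bookkeeping (integers stand in for tensors).
num_warmup = 2
num_microbatches = 5
input_tensors, output_tensors = [], []

# Warmup: forward passes only; activations are queued.
for i in range(num_warmup):
    input_tensors.append(i)
    output_tensors.append(i * 10)

# Steady state (1F1B): one forward, then one backward on the oldest saved pair.
for i in range(num_warmup, num_microbatches):
    input_tensors.append(i)
    output_tensors.append(i * 10)
    bwd_in, bwd_out = input_tensors.pop(0), output_tensors.pop(0)
    print('backward uses microbatch', bwd_in)

# Cooldown: drain the remaining saved activations.
while input_tensors:
    print('cooldown backward uses microbatch', input_tensors.pop(0))
    output_tensors.pop(0)
```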
@@ -47,9 +47,7 @@ from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
-from megatron.schedules import forward_backward_no_pipelining
-from megatron.schedules import forward_backward_pipelining_without_interleaving
-from megatron.schedules import forward_backward_pipelining_with_interleaving
+from megatron.schedules import get_forward_backward_func
from megatron.utils import report_memory
@@ -98,7 +96,7 @@ def pretrain(train_valid_test_dataset_provider,
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
-    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
+    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
@@ -255,7 +253,7 @@ def get_model(model_provider_func):
    if args.DDP_impl == 'local':
        model = [LocalDDP(model_module,
                          args.accumulate_allreduce_grads_in_fp32,
-                          args.use_contiguous_buffers_in_ddp)
+                          args.use_contiguous_buffers_in_local_ddp)
                 for model_module in model]
        return model
@@ -353,26 +351,20 @@ def train_step(forward_step_func, data_iterator,
    timers = get_timers()

    # Set grad to zero.
-    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
+    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
-    else:
-        optimizer.zero_grad()
+    optimizer.zero_grad()

-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
-            forward_backward_func = forward_backward_pipelining_with_interleaving
-            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
-                'number of microbatches is not divisible by pipeline-parallel ' \
-                'size when using interleaved schedule'
-        else:
-            forward_backward_func = forward_backward_pipelining_without_interleaving
-    else:
-        forward_backward_func = forward_backward_no_pipelining
+    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

+    # Empty unused memory
+    if args.empty_unused_memory_level >= 1:
+        torch.cuda.empty_cache()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
@@ -419,6 +411,10 @@ def train_step(forward_step_func, data_iterator,
    else:
        skipped_iter = 1

+    # Empty unused memory
+    if args.empty_unused_memory_level >= 2:
+        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
@@ -531,11 +527,28 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)
+        if args.log_memory_to_tensorboard:
+            mem_stats = torch.cuda.memory_stats()
+            writer.add_scalar(
+                "mem-reserved-bytes",
+                mem_stats["reserved_bytes.all.current"],
+                iteration,
+            )
+            writer.add_scalar(
+                "mem-allocated-bytes",
+                mem_stats["allocated_bytes.all.current"],
+                iteration,
+            )
+            writer.add_scalar(
+                "mem-allocated-count",
+                mem_stats["allocation.all.current"],
+                iteration,
+            )

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
-        if writer and torch.distributed.get_rank() == 0:
+        if writer:
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
@@ -705,17 +718,15 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

-            if mpu.get_pipeline_model_parallel_world_size() > 1:
-                if args.virtual_pipeline_model_parallel_size is not None:
-                    forward_backward_func = forward_backward_pipelining_with_interleaving
-                else:
-                    forward_backward_func = forward_backward_pipelining_without_interleaving
-            else:
-                forward_backward_func = forward_backward_no_pipelining
+            forward_backward_func = get_forward_backward_func()
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

+            # Empty unused memory
+            if args.empty_unused_memory_level >= 1:
+                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
@@ -748,7 +759,7 @@ def evaluate_and_print_results(prefix, forward_step_func,
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
-        if writer and is_last_rank():
+        if writer:
            writer.add_scalar('{} validation'.format(key),
                              total_loss_dict[key].item(),
                              iteration)
@@ -787,10 +798,9 @@ def build_train_valid_test_data_iterators(
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
-        assert args.train_samples is None, \
-            'only backward compatiblity support for iteration-based training'
-        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
-            args.eval_iters * args.global_batch_size
+        if args.train_samples is None:
+            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
+                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:
...
## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
Below we present the steps to run unsupervised and supervised training and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
## Retriever Training
#### Unsupervised pretraining
1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
<pre>
python tools/preprocess_data.py \
--input /path/to/corpus.json \
--json-keys text title \
--split-sentences \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file /path/to/vocab.txt \
--output-prefix corpus_indexed \
--workers 10
</pre>
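The input corpus is expected in loose JSON format, i.e. one JSON object per line, here with `text` and `title` keys to match `--json-keys text title`. A toy sketch of producing such a file (the file name and document contents below are placeholders, not part of any released dataset):
<pre>
import json

docs = [
    {"title": "Alan Turing", "text": "Alan Turing was an English mathematician. "
                                     "He is widely considered a founder of theoretical computer science."},
    {"title": "Wikipedia",   "text": "Wikipedia is a free online encyclopedia. "
                                     "It is written and maintained by volunteers."},
]

with open("corpus.json", "w") as f:
    for doc in docs:
        f.write(json.dumps(doc) + "\n")   # one JSON object per line (loose JSON)
</pre>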
2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single-GPU 217M-parameter biencoder model for ICT retriever training. Single-GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model and a total batch size of 4096 for the ICT training.
3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).
#### Supervised finetuning
1. Use the above pretrained ICT model to finetune using [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example of how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top of [DPR training](https://arxiv.org/abs/2004.04906).
2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model.
More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).
## Reader Training
The reader component will be available soon.
@@ -244,6 +244,9 @@ def normalize_question(question):
        question = question[:-1]
    return question


+# The following class reads the datasets for training retriever as
+# prepared by the DPR codebase (https://github.com/facebookresearch/DPR)
+
class NQSupervisedDataset(OpenRetrievalAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length, \
...
@@ -33,6 +33,28 @@ from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator

+# input_ is a 2D tensor
+def check_and_append_tensor_for_gather(group, rank, world_size, input_):
+
+    # gather the size of the first dimension of the tensor from all ranks
+    current_length = input_.size()[0]
+    first_dim = torch.tensor([[current_length]],
+                             device=torch.cuda.current_device())
+    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
+    input_list[rank].copy_(first_dim)
+    torch.distributed.all_gather(input_list, first_dim, group=group)
+    all_input_list = torch.cat(input_list, dim=0).contiguous()
+    max_length = torch.max(all_input_list)
+
+    # if the sizes differ from the max, extend the tensor accordingly
+    if max_length > current_length:
+        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
+            tuple([max_length - current_length])
+        input_ = F.pad(input=input_, pad=padding)
+
+    return input_
def orqa(Dataset):

    def cross_entropy_forward_step(batch, model):
@@ -47,6 +69,8 @@ def orqa(Dataset):
        except BaseException:
            batch_ = batch

+        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
        context_tokens, context_mask, context_types, context_pad_mask, \
        neg_context_tokens, neg_context_mask, neg_context_types, \
@@ -61,6 +85,14 @@ def orqa(Dataset):
            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
            context_list.append(tokenizer.decode(context_tokens[i].tolist()))

+        if neg_context_tokens is not None:
+            neg_context_tokens = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_tokens)
+            neg_context_mask = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_mask)
+            neg_context_types = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_types)

        if neg_context_tokens is not None:
            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
@@ -70,7 +102,6 @@ def orqa(Dataset):
        output_tensor = model(query_tokens, query_mask,
                              query_types, context_tokens,
                              context_mask, context_types)

        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
...
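The `check_and_append_tensor_for_gather` helper added above pads each rank's 2-D tensor along its first dimension up to the maximum length across ranks, so that `all_gather` sees equally shaped tensors. A single-process sketch of just the padding step (the distributed size exchange is omitted, and this is illustrative rather than the actual helper):

```python
import torch
import torch.nn.functional as F

def pad_first_dim_to(input_, max_length):
    """Zero-pad a 2-D tensor along dim 0 up to max_length rows."""
    current_length = input_.size(0)
    if max_length > current_length:
        # F.pad's final pair of values pads the first dimension.
        padding = tuple([0] * (input_.dim() * 2 - 1)) + (max_length - current_length,)
        input_ = F.pad(input=input_, pad=padding)
    return input_

a = torch.ones(3, 5)          # this rank has 3 rows
b = pad_first_dim_to(a, 6)    # another rank had 6 rows
print(b.shape)                # torch.Size([6, 5])
```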
@@ -20,7 +20,7 @@ python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for
4. Merge the contents into one loose json file with 1 json per newline of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique.

-# Prepare the data for GPT-2 training:
+# Prepare the data for GPT training:

1. Perform ftfy, english detection and remove documents with less than 128 tokens. This step can be sharded and run on shards.
```
@@ -50,7 +50,7 @@ shuf <cleaned deduped data file> -o train_data.json
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
```
-python filter_ngrams.py --tasks <name of he task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
+python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
```
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of it. We also remove any resulting piece with fewer than 200 characters, or any document that gets split more than 10 times. These parameters can be changed using the corresponding arguments.
...
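The 13-gram removal rule described above can be sketched as follows. This is an illustration of the rule only, not the actual `filter_ngrams.py` implementation; the 200-character trim and minimum-piece values mirror the defaults mentioned in the text, and the example document is made up.

```python
# Illustrative sketch of the 13-gram removal rule (not the real filter_ngrams.py).
def split_around_match(document, matched_ngram, trim_chars=200, min_piece_chars=200):
    idx = document.find(matched_ngram)
    if idx < 0:
        return [document]
    left = document[:max(0, idx - trim_chars)]
    right = document[idx + len(matched_ngram) + trim_chars:]
    # Keep only pieces that are still long enough after trimming.
    return [piece for piece in (left, right) if len(piece) >= min_piece_chars]

ngram = "the quick brown fox jumps over the lazy dog again and again today"  # 13 words
doc = ("x" * 500) + " " + ngram + " " + ("y" * 500)
pieces = split_around_match(doc, ngram)
print(len(pieces), [len(p) for p in pieces])
```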