Commit 6b5ae488 authored by rprenger

Fixing merge conflicts with main

parents 3d718bfc 3860e995
@@ -240,9 +240,15 @@ def parse_args(extra_args_provider=None, defaults={},
             'residual connection in fp32 only supported when using fp16 or bf16.'

     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
         assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
-            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
+            'need to use an activation-checkpoint method'
+        assert args.num_layers_per_virtual_pipeline_stage is None, \
+            'currently distributed checkpoint activations are only supported ' \
+            'for non-interleaved pipeline parallelism'

     _print_args(args)
     return args
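The three added checks admit a single flag combination. A minimal sketch of a configuration that passes them, using a plain argparse.Namespace as a stand-in for Megatron's parsed args (illustrative only, not code from this commit); any other combination trips the corresponding assertion in parse_args:

    # Illustrative only -- a Namespace standing in for Megatron's parsed args.
    from argparse import Namespace

    args = Namespace(
        distribute_checkpointed_activations=True,
        tensor_model_parallel_size=2,                # must be > 1
        activations_checkpoint_method='uniform',     # must not be None
        num_layers_per_virtual_pipeline_stage=None,  # interleaving unsupported
    )

    assert args.tensor_model_parallel_size > 1
    assert args.activations_checkpoint_method is not None
    assert args.num_layers_per_virtual_pipeline_stage is None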
@@ -64,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

+    # Set pytorch JIT layer fusion options.
+    _set_jit_fusion_options()
+
     args = get_args()
     if args.lazy_mpu_init:
         args.use_cpu_initialization=True
@@ -78,9 +81,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Megatron's MPU is the master. Complete initialization right away.
         finish_mpu_init()

-        # Initialize memory buffers.
-        _initialize_mem_buffs()
-
         # Autoresume.
         _init_autoresume()
@@ -226,10 +226,24 @@ def write_args_to_tensorboard():
                             global_step=args.iteration)


-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
-
-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
+def _set_jit_fusion_options():
+    """Set PyTorch JIT layer fusion options."""
+    # flags required to enable jit fusion kernels
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        # nvfuser
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._debug_set_autodiff_subgraph_inlining(False)
+    else:
+        # legacy pytorch fuser
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
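The choice between nvfuser and the legacy fuser now depends only on the installed PyTorch version instead of being fixed at module import time. A minimal sketch of that gate, assuming only that torch.__version__ begins with "major.minor" (e.g. "1.10.2" or "1.13.1+cu117"):

    import torch

    # Same version parse as _set_jit_fusion_options(); builds >= 1.10 take
    # the nvfuser branch, older builds keep the legacy fuser flags.
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    use_nvfuser = (TORCH_MAJOR, TORCH_MINOR) >= (1, 10)
    print('JIT fuser selected:', 'nvfuser' if use_nvfuser else 'legacy fuser')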
@@ -15,10 +15,6 @@
 import torch

-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)

 ###### BIAS GELU FUSION/ NO AUTOGRAD ################
 # 1/sqrt(2*pi)-> 0.3989423
@@ -27,11 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

-# flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)

 """ We use the following notation throughout this file:
     h: hidden size
@@ -544,6 +539,7 @@ class ParallelTransformer(MegatronModule):
         # Store activation checkpointing flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method
         self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
+        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

         # Number of layers.
         assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
@@ -607,8 +603,22 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
+        def distribute_checkpointed_activations_helper(layer_number):
+            """Distribute checkpointed activations across the tensor model
+            parallel ranks if the `distribute-checkpointed-activations` flag
+            is on and either of the following conditions is met:
+              - this is not the first layer in the pipeline stage. The first
+                layer's input is used by pipeline parallelism, and changing
+                its shape throws an error in the backward pass.
+              - we are at the first pipeline stage, so the input tensor is
+                not used in pipeline parallelism. Note that no pipeline
+                parallelism is a special case of this.
+            """
+            not_first_layer_in_pipeline_stage = (layer_number > 0)
+            is_first_pipeline_stage = (
+                mpu.get_pipeline_model_parallel_rank() == 0)
+            return self.distribute_checkpointed_activations and \
+                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
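The helper's predicate can be restated on its own. A small illustrative sketch (should_distribute and its arguments are hypothetical names, not part of this commit) covering the cases described in the docstring above:

    # Hypothetical restatement of distribute_checkpointed_activations_helper.
    def should_distribute(layer_number, pipeline_rank, flag_on):
        not_first_layer_in_stage = layer_number > 0
        is_first_pipeline_stage = pipeline_rank == 0
        return flag_on and (not_first_layer_in_stage or is_first_pipeline_stage)

    assert should_distribute(3, 1, True)       # later layer: always safe
    assert should_distribute(0, 0, True)       # first layer, first stage: safe
    assert not should_distribute(0, 1, True)   # first layer, later stage: keep shape
    assert not should_distribute(3, 1, False)  # flag off: never distribute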
@@ -618,6 +628,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
+                    distribute_checkpointed_activations_helper(l),
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -628,6 +639,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
+                        distribute_checkpointed_activations_helper(l),
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
@@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region

 from .random import checkpoint
 from .random import get_cuda_rng_tracker
-from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
-from .random import reset_checkpointed_activations_memory_buffer
 from .random import gather_split_1d_tensor
 from .random import split_tensor_into_1d_equal_chunks
@@ -37,46 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size

 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

-# Whether apply model parallelsim to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-
-
-def init_checkpointed_activations_memory_buffer():
-    """Initializ the memory buffer for the checkpointed activations."""
-    args = get_args()
-
-    per_layer = args.micro_batch_size * args.max_position_embeddings * \
-        args.hidden_size // args.tensor_model_parallel_size
-    num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
-    if args.virtual_pipeline_model_parallel_size is not None:
-        num_layers = num_layers // args.virtual_pipeline_model_parallel_size
-
-    if args.activations_checkpoint_method == 'uniform':
-        assert num_layers % args.activations_checkpoint_num_layers == 0, \
-            'total number of layers is not divisible by checkpoint-chunk_size'
-        num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers
-    elif args.activations_checkpoint_method == 'block':
-        assert args.activations_checkpoint_num_layers <= num_layers, \
-            'total number of layers is fewer than the number of layers to checkpoint'
-        num_checkpointer_layers = args.activations_checkpoint_num_layers
-
-    numel = per_layer * num_checkpointer_layers
-    dtype = torch.half
-    if not args.fp16:
-        dtype = torch.float
-
-    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-        'checkpointed activations memory buffer is already allocated.'
-    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-        'checkpointed activations', numel, dtype, track_usage=False)
-
-
-def reset_checkpointed_activations_memory_buffer():
-    """Reset the memory used for checkpointing."""
-    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()


 def _set_cuda_rng_state(new_state, device=-1):
     """Sets the random number generator state of the current GPU.
@@ -110,14 +70,21 @@ def _set_cuda_rng_state(new_state, device=-1):
         _lazy_call(cb)


-def split_tensor_into_1d_equal_chunks(tensor):
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     """Break a tensor into equal 1D chunks."""
-    data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+    partition_size = torch.numel(tensor) // \
+        get_tensor_model_parallel_world_size()
     start_index = partition_size * get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
-    return data[start_index:end_index]
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data


 def gather_split_1d_tensor(tensor):
     """Opposite of above function, gather values from model parallel ranks."""
@@ -259,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function):
     tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
         ctx.run_function = run_function
+        ctx.distribute_checkpointed_activations \
+            = distribute_checkpointed_activations

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -272,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function):

         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if distribute_checkpointed_activations:
             ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
-                args[0].data)
+            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
+                                                             new_buffer=True)

         # Store everything.
         ctx.save_for_backward(*args)

         return outputs

     @staticmethod
@@ -290,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
+        if ctx.distribute_checkpointed_activations:
             inputs[0].data = gather_split_1d_tensor(inputs[0].data)
             inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
@@ -319,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs, args)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                       for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, distribute_checkpointed_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function,
+                                    distribute_checkpointed_activations, *args)
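autograd expects backward to return one gradient slot per forward argument, so the two non-tensor inputs (the run function and the new boolean) each receive None, hence the change to (None, None) + grads. A toy autograd.Function, unrelated to Megatron's classes, showing the same contract:

    import torch

    class Scale(torch.autograd.Function):
        """Toy Function with two non-tensor leading arguments."""
        @staticmethod
        def forward(ctx, run_function, flag, x):
            ctx.flag = flag
            return run_function(x)

        @staticmethod
        def backward(ctx, grad_output):
            # One None per non-tensor input, then the gradient w.r.t. x.
            return None, None, grad_output * 2.0

    x = torch.ones(3, requires_grad=True)
    y = Scale.apply(lambda t: 2.0 * t, True, x)
    y.sum().backward()
    assert torch.allclose(x.grad, torch.full((3,), 2.0))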