Commit 5d29769c authored by mohammad

Addressed Jared's comments

parent d6c4248b
@@ -112,6 +112,11 @@ def parse_args(extra_args_provider=None, defaults={},
     # Mixed precision checks.
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'
+    # Activation checkpointing.
+    if args.distribute_checkpointed_activations:
+        assert args.checkpoint_activations, \
+            'for distribute-checkpointed-activations to work you '\
+            'need to enable checkpoint-activations'
     _print_args(args)
     return args
...
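The new check enforces a dependency between two boolean flags: distributing checkpointed activations only makes sense when activation checkpointing itself is enabled. A minimal standalone sketch of that behavior, using argparse flags that are assumed to mirror the real Megatron options defined elsewhere in parse_args():

# Hypothetical reproduction of the flag dependency, not the actual parse_args() code.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint-activations', action='store_true')
parser.add_argument('--distribute-checkpointed-activations', action='store_true')

# Valid combination: both flags enabled.
args = parser.parse_args(['--checkpoint-activations',
                          '--distribute-checkpointed-activations'])
if args.distribute_checkpointed_activations:
    assert args.checkpoint_activations, \
        'for distribute-checkpointed-activations to work you ' \
        'need to enable checkpoint-activations'

# Passing only --distribute-checkpointed-activations would trip the assert.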
@@ -162,13 +162,4 @@ def _initialize_mem_buffs():
     # Initialize memory for checkpointed activations.
     if args.distribute_checkpointed_activations:
-        per_layer = args.batch_size * args.max_position_embeddings * \
-            args.hidden_size // args.model_parallel_size
-        assert args.num_layers % args.checkpoint_num_layers == 0, \
-            'number of layers is not divisible by checkpoint-num-layers'
-        num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-        numel = per_layer * num_checkpointer_layers
-        dtype = torch.half
-        if not args.fp16:
-            dtype = torch.float
-        mpu.init_checkpointed_activations_memory_buffer(numel, dtype)
+        mpu.init_checkpointed_activations_memory_buffer()
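Moving the size computation into mpu reduces _initialize_mem_buffs() to a single call, but the arithmetic is unchanged. A back-of-the-envelope sizing example with hypothetical settings (illustrative values, not Megatron defaults):

# Hypothetical configuration: batch 8, sequence length 1024, hidden size 1024,
# 2-way model parallelism, 24 layers checkpointed one layer at a time, fp16.
batch_size = 8
max_position_embeddings = 1024
hidden_size = 1024
model_parallel_size = 2
num_layers = 24
checkpoint_num_layers = 1
fp16 = True

# One checkpointed layer stores batch * seq-len * hidden activations,
# split across model-parallel ranks.
per_layer = batch_size * max_position_embeddings * hidden_size // model_parallel_size
assert num_layers % checkpoint_num_layers == 0
num_checkpointer_layers = num_layers // checkpoint_num_layers
numel = per_layer * num_checkpointer_layers          # 100,663,296 elements
bytes_per_element = 2 if fp16 else 4
print(numel * bytes_per_element / 2**20, 'MiB')      # ~192 MiB in half precision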
@@ -24,6 +24,7 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable
+from megatron import get_args
 from megatron.memory import allocate_mem_buff
 from .initialize import get_data_parallel_rank
@@ -40,9 +41,23 @@ _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
 _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-def init_checkpointed_activations_memory_buffer(numel, dtype):
+def init_checkpointed_activations_memory_buffer():
     """Initialize the memory buffer for the checkpointed activations."""
+    args = get_args()
+    per_layer = args.batch_size * args.max_position_embeddings * \
+                args.hidden_size // args.model_parallel_size
+    assert args.num_layers % args.checkpoint_num_layers == 0, \
+        'number of layers is not divisible by checkpoint-num-layers'
+    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+    numel = per_layer * num_checkpointer_layers
+    dtype = torch.half
+    if not args.fp16:
+        dtype = torch.float
     global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
+    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
+        'checkpointed activations memory buffer is already allocated.'
     _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
         'checkpointed activations', numel, dtype, track_usage=False)
...
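The refactored initializer now pulls its configuration from get_args() and guards against double allocation via a module-level global. The same allocate-once pattern in isolation, with allocate_mem_buff stubbed out by torch.empty purely for illustration (the real helper lives in megatron.memory):

import torch

# Module-level singleton: the buffer may be allocated at most once.
_BUFFER = None

def init_buffer(numel, dtype):
    """Sketch of the allocate-once pattern; not the Megatron implementation."""
    global _BUFFER
    assert _BUFFER is None, 'buffer is already allocated.'
    _BUFFER = torch.empty(numel, dtype=dtype)
    return _BUFFER

init_buffer(1024, torch.half)     # first call allocates
# init_buffer(1024, torch.half)   # a second call would trip the assert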