Commit 3499542e authored by John Kamalu

Transformer Engine Integration Rebase

parent 52e63688
......@@ -28,6 +28,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_vision_args(parser)
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)
# Custom arguments.
if extra_args_provider is not None:
......@@ -304,6 +305,18 @@ def validate_args(args, defaults={}):
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
# Transformer-Engine/FP8 related checking
if args.fp8_e4m3 or args.fp8_hybrid:
assert args.transformer_impl == 'transformer_engine', \
'transformer-engine required for fp8 training and inference'
assert not (args.fp8_e4m3 and args.fp8_hybrid), \
'cannot train with both fp8 e4m3 and hybrid formatting'
if args.fp16:
assert args.transformer_impl == 'local', \
'transformer-engine not yet approved for fp16 training and inference'
if args.recompute_granularity == 'selective':
assert args.recompute_method is None, \
'recompute method is not yet supported for ' \
......@@ -355,6 +368,33 @@ def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def _add_transformer_engine_args(parser):
group = parser.add_argument_group(title='Transformer-Engine')
group.add_argument('--fp8-e4m3', action='store_true',
help='E4M3 TransformerLayer', dest='fp8_e4m3')
group.add_argument('--fp8-hybrid', action='store_true',
help='Hybrid FP8 TransformerLayer', dest='fp8_hybrid')
group.add_argument('--no-fp8-wgrad', action='store_false',
help='Execute wgrad in higher precision even for FP8 runs', dest='fp8_wgrad')
group.add_argument('--fp8-margin', type=int, default=0,
help='Scaling margin for fp8', dest='fp8_margin')
group.add_argument('--fp8-interval', type=int, default=1,
help='Scaling update interval for fp8', dest='fp8_interval')
group.add_argument('--transformer-impl', default='local',
choices=['local', 'transformer_engine'],
help='Which Transformer implementation to use.',
dest='transformer_impl')
group.add_argument('--fp8-amax-history-len', type=int, default=1,
help='Number of steps for which amax history is recorded per tensor',
dest='fp8_amax_history_len')
group.add_argument('--fp8-amax-compute-algo', default='most_recent',
choices=['most_recent', 'max'],
help='Algorithm for computing amax from history',
dest='fp8_amax_compute_algo')
return parser
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
......
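As a reviewer's note, here is a minimal, self-contained sketch (not part of the commit) of how the new flags behave: `--no-fp8-wgrad` is a `store_false` option, so `fp8_wgrad` defaults to True and the flag turns it off, while `validate_args` above additionally requires `--transformer-impl transformer_engine` whenever either FP8 format is requested.

```python
import argparse

# Sketch only: a stripped-down copy of the Transformer-Engine argument group.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='Transformer-Engine')
group.add_argument('--fp8-hybrid', action='store_true', dest='fp8_hybrid')
group.add_argument('--no-fp8-wgrad', action='store_false', dest='fp8_wgrad')
group.add_argument('--transformer-impl', default='local',
                   choices=['local', 'transformer_engine'], dest='transformer_impl')

# A hybrid-FP8 run that keeps wgrad in higher precision:
args = parser.parse_args(['--fp8-hybrid', '--no-fp8-wgrad',
                          '--transformer-impl', 'transformer_engine'])
assert args.fp8_hybrid and not args.fp8_wgrad
assert args.transformer_impl == 'transformer_engine'
```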
......@@ -18,11 +18,14 @@ def load(args):
# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = _get_cuda_bare_metal_version(
_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
if int(bare_metal_minor) >= 7:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_90,code=sm_90')
# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
......@@ -75,11 +78,14 @@ def load(args):
# Mixed precision fused layer norm.
# =================================
extra_hopper_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__']
extra_cuda_flags = ['-maxrregcount=50']
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu']
fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags + extra_hopper_flags)
# =================================
# Fused gradient accumulation into the weight gradient computation of the linear layer
......@@ -89,7 +95,7 @@ def load(args):
sources=[srcpath / 'fused_weight_gradient_dense.cpp',
srcpath / 'fused_weight_gradient_dense.cu']
fused_dense_cuda = _cpp_extention_load_helper(
"fused_dense_cuda", sources, [])
"fused_dense_cuda", sources, extra_hopper_flags)
def _get_cuda_bare_metal_version(cuda_dir):
......
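For clarity, a small illustrative helper (not in the commit) that mirrors the version gate added to `load()` above: the Ampere `sm_80` flags are emitted for CUDA 11 and newer, and the Hopper `sm_90` target is appended once the bare-metal minor version reaches 7 (i.e. CUDA 11.7+ within the 11.x series).

```python
# Illustrative only: same -gencode accumulation as the load() change above.
def build_cc_flags(bare_metal_major: str, bare_metal_minor: str):
    cc_flag = []
    if int(bare_metal_major) >= 11:
        cc_flag += ['-gencode', 'arch=compute_80,code=sm_80']      # Ampere
        if int(bare_metal_minor) >= 7:
            cc_flag += ['-gencode', 'arch=compute_90,code=sm_90']  # Hopper
    return cc_flag

print(build_cc_flags('11', '8'))
# ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90,code=sm_90']
```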
......@@ -6,7 +6,7 @@ from contextlib import nullcontext
import torch
import torch.nn.functional as F
from megatron import get_timers, get_args, core
from megatron import get_timers, get_args, core, get_num_microbatches
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
......@@ -15,7 +15,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
......@@ -810,6 +809,7 @@ class ParallelTransformer(MegatronModule):
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
self.transformer_impl = args.transformer_impl
# Store activation checkpointing flag.
self.recompute_granularity = args.recompute_granularity
......@@ -820,6 +820,31 @@ class ParallelTransformer(MegatronModule):
self.sequence_parallel = args.sequence_parallel
# Transformer Engine Init.
if self.transformer_impl == 'transformer_engine':
global transformer_engine
import transformer_engine
self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
self.fp8_recipe = None
self.fp8_group = mpu.get_data_parallel_group()
if self.use_fp8:
if args.fp8_e4m3:
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif args.fp8_hybrid:
fp8_format = transformer_engine.common.recipe.Format.HYBRID
self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=args.fp8_margin,
interval=args.fp8_interval,
fp8_format=fp8_format,
amax_history_len=args.fp8_amax_history_len,
amax_compute_algo=args.fp8_amax_compute_algo,
override_linear_precision=(False, False, not args.fp8_wgrad),
)
self.num_microbatches_in_previous_step = -1
self.microbatch_count = 0
self.checkpoint_core_attention = args.recompute_granularity == 'selective'
# Number of layers.
self.num_layers = _get_num_layers(
args,
......@@ -830,6 +855,7 @@ class ParallelTransformer(MegatronModule):
# Transformer layers.
def build_layer(layer_number):
if args.transformer_impl == 'local':
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
......@@ -837,6 +863,35 @@ class ParallelTransformer(MegatronModule):
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
return transformer_engine.pytorch.TransformerLayer(
args.hidden_size,
args.ffn_hidden_size,
args.num_attention_heads,
layernorm_epsilon=args.layernorm_epsilon,
hidden_dropout=args.hidden_dropout,
attention_dropout=args.attention_dropout,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=args.kv_channels,
self_attn_mask_type=self_attn_mask_type.name,
tp_group=mpu.get_tensor_model_parallel_group(),
get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32,
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
sequence_parallel=args.sequence_parallel,
params_dtype=args.params_dtype,
apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=self.drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=True)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
......@@ -896,19 +951,20 @@ class ParallelTransformer(MegatronModule):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
encoder_output, enc_dec_attn_mask, is_first_microbatch):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
def custom(start, end, is_transformer_engine=False):
def custom_forward(*args, **kwargs):
x_, *args = args
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
x_ = layer(x_, *args, **kwargs)
return x_
def custom_forward_transformer_engine(*args, **kwargs):
return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
if not is_transformer_engine:
return custom_forward
else:
return custom_forward_transformer_engine
if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
......@@ -916,10 +972,19 @@ class ParallelTransformer(MegatronModule):
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
if self.transformer_impl == 'transformer_engine':
hidden_states = transformer_engine.pytorch.distributed.checkpoint(
custom(l, l + self.recompute_num_layers, is_transformer_engine=True),
self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.recompute_num_layers
elif self.recompute_method == 'block':
......@@ -928,10 +993,22 @@ class ParallelTransformer(MegatronModule):
# A method that fully uses the device memory, removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
if self.transformer_impl == 'transformer_engine':
hidden_states = transformer_engine.pytorch.distributed.checkpoint(
custom(l, l + 1, is_transformer_engine=True),
self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
if self.transformer_impl == 'transformer_engine':
hidden_states = custom(l, l + 1, is_transformer_engine=True)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
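A compact sketch (stand-in names, not Megatron or Transformer Engine APIs) of the closure pattern `custom()` relies on above: the Transformer-Engine variant bakes `is_first_microbatch` into the runner when it is created, so the checkpoint call only has to replay the positional tensor arguments.

```python
# Toy demo of pinning a keyword argument at closure-creation time.
def make_chunk_runner(layers, is_first_microbatch):
    def run(*args, **kwargs):
        x, *rest = args
        for layer in layers:
            x = layer(x, *rest, **kwargs)   # thread the running activation through
        return x

    def run_te(*args, **kwargs):
        # Same runner, with the microbatch flag already bound.
        return run(*args, is_first_microbatch=is_first_microbatch, **kwargs)

    return run_te

# Fake "layers" that just record the flag they were handed.
trace = []
layers = [lambda x, mask, is_first_microbatch: trace.append(is_first_microbatch) or x
          for _ in range(2)]
runner = make_chunk_runner(layers, is_first_microbatch=True)
runner("hidden", "mask")
assert trace == [True, True]
```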
......@@ -991,21 +1068,48 @@ class ParallelTransformer(MegatronModule):
rng_context = nullcontext()
with rng_context:
# The fp8_autocast context manager is a no-op when enabled=False
# The if...else serves to short circuit name resolution for fp8_autocast
with transformer_engine.pytorch.fp8_autocast(
enabled=self.use_fp8,
fp8_recipe=self.fp8_recipe,
fp8_group=self.fp8_group
) if self.use_fp8 else nullcontext():
# Determine if the current iteration is the first microbatch
if self.num_microbatches_in_previous_step != get_num_microbatches():
self.microbatch_count = 0 # Reset count on new batch size rampup interval
self.num_microbatches_in_previous_step = get_num_microbatches()
is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
# Forward pass.
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
enc_dec_attn_mask,
is_first_microbatch)
else:
forward_kwargs = {
'encoder_output': encoder_output,
'enc_dec_attn_mask': enc_dec_attn_mask,
'inference_params': inference_params,
}
if self.transformer_impl == 'transformer_engine':
forward_kwargs['is_first_microbatch'] = is_first_microbatch
forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
**forward_kwargs)
# Skip counter update for eval and activation checkpointing
if torch.is_grad_enabled() and self.training:
self.microbatch_count += 1
# Final layer norm.
if self.post_process and self.post_layer_norm:
......
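To summarize the FP8 plumbing, a hedged, standalone sketch (assuming the `transformer_engine` package is installed; newer releases may have renamed or dropped some `DelayedScaling` arguments) of how the CLI flags above become a recipe and an `fp8_autocast` context, mirroring `ParallelTransformer.__init__` and `forward()`:

```python
from contextlib import nullcontext

import transformer_engine
from transformer_engine.common import recipe

# Default flag values from _add_transformer_engine_args, mapped to the recipe.
fp8_recipe = recipe.DelayedScaling(
    margin=0,                         # --fp8-margin
    interval=1,                       # --fp8-interval
    fp8_format=recipe.Format.HYBRID,  # --fp8-hybrid (Format.E4M3 for --fp8-e4m3)
    amax_history_len=1,               # --fp8-amax-history-len
    amax_compute_algo='most_recent',  # --fp8-amax-compute-algo
    override_linear_precision=(False, False, False),  # third entry True with --no-fp8-wgrad
)

use_fp8 = False  # True only on FP8-capable hardware with --fp8-e4m3 or --fp8-hybrid
# fp8_autocast(enabled=False) is a no-op; the if/else mirrors the short circuit in forward().
ctx = (transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
       if use_fp8 else nullcontext())
with ctx:
    ...  # run TransformerLayer forward passes here, once per microbatch
```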
......@@ -26,6 +26,7 @@ from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.model import GPTModel
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
......@@ -251,6 +252,12 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
if not isinstance(model, list):
model = [model]
# Disallow training and inference with Transformer Engine
# for non-GPT models
args.allow_transformer_engine = all([type(m) == GPTModel for m in model])
assert args.allow_transformer_engine or args.transformer_impl == 'local', \
'Transformer Engine is only approved for GPT models'
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
......
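A toy illustration (hypothetical stand-in classes, not the real Megatron models) of the gating added above: because the check uses an exact `type()` comparison, only chunks that are literally `GPTModel` enable Transformer Engine; any other model, including a `GPTModel` subclass, keeps the requirement of `--transformer-impl local`.

```python
class GPTModel: ...    # stand-in for megatron.model.GPTModel
class BertModel: ...   # stand-in for a non-GPT model

def allow_transformer_engine(model_chunks):
    # Mirrors the all([type(m) == GPTModel ...]) check in get_model().
    return all(type(m) == GPTModel for m in model_chunks)

assert allow_transformer_engine([GPTModel(), GPTModel()])
assert not allow_transformer_engine([GPTModel(), BertModel()])

class MyGPT(GPTModel): ...
assert not allow_transformer_engine([MyGPT()])  # subclasses fail the exact-type check
```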
......@@ -113,4 +113,6 @@ if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
)