Commit c92f10bd authored by Jared Casper

Merge branch 'main' into tridao-flashattn

parents 9200e43a b7071993
@@ -28,6 +28,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     parser = _add_vision_args(parser)
     parser = _add_logging_args(parser)
     parser = _add_inference_args(parser)
+    parser = _add_transformer_engine_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
@@ -304,6 +305,18 @@ def validate_args(args, defaults={}):
         'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
         'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
 
+    # Transformer-Engine/FP8 related checking
+    if args.fp8_e4m3 or args.fp8_hybrid:
+        assert args.transformer_impl == 'transformer_engine', \
+            'transformer-engine required for fp8 training and inference'
+
+    assert not (args.fp8_e4m3 and args.fp8_hybrid), \
+        'cannot train with both fp8 e4m3 and hybrid formatting'
+
+    if args.fp16:
+        assert args.transformer_impl == 'local', \
+            'transformer-engine not yet approved for fp16 training and inference'
+
     if args.recompute_granularity == 'selective':
         assert args.recompute_method is None, \
             'recompute method is not yet supported for ' \
@@ -355,6 +368,33 @@ def _check_arg_is_not_none(args, arg):
     assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
 
 
+def _add_transformer_engine_args(parser):
+    group = parser.add_argument_group(title='Transformer-Engine')
+
+    group.add_argument('--fp8-e4m3', action='store_true',
+                       help='E4M3 TransformerLayer', dest='fp8_e4m3')
+    group.add_argument('--fp8-hybrid', action='store_true',
+                       help='Hybrid FP8 TransformerLayer', dest='fp8_hybrid')
+    group.add_argument('--no-fp8-wgrad', action='store_false',
+                       help='Execute wgrad in higher precision even for FP8 runs', dest='fp8_wgrad')
+    group.add_argument('--fp8-margin', type=int, default=0,
+                       help='Scaling margin for fp8', dest='fp8_margin')
+    group.add_argument('--fp8-interval', type=int, default=1,
+                       help='Scaling update interval for fp8', dest='fp8_interval')
+    group.add_argument('--transformer-impl', default='local',
+                       choices=['local', 'transformer_engine'],
+                       help='Which Transformer implementation to use.',
+                       dest='transformer_impl')
+    group.add_argument('--fp8-amax-history-len', type=int, default=1,
+                       help='Number of steps for which amax history is recorded per tensor',
+                       dest='fp8_amax_history_len')
+    group.add_argument('--fp8-amax-compute-algo', default='most_recent',
+                       choices=['most_recent', 'max'],
+                       help='Algorithm for computing amax from history',
+                       dest='fp8_amax_compute_algo')
+
+    return parser
+
+
 def _add_inference_args(parser):
     group = parser.add_argument_group(title='inference')
......
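The new arguments wire FP8 into the existing argparse setup: --transformer-impl selects the layer implementation, --fp8-e4m3 and --fp8-hybrid pick the FP8 format, and the remaining flags configure delayed scaling. Below is a minimal standalone sketch of how the flags and the new validate_args checks interact; it is a hypothetical example that reproduces only a subset of the options defined above.

# Hypothetical standalone sketch; the real flags are defined in
# _add_transformer_engine_args and checked in validate_args above.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='Transformer-Engine')
group.add_argument('--fp8-e4m3', action='store_true', dest='fp8_e4m3')
group.add_argument('--fp8-hybrid', action='store_true', dest='fp8_hybrid')
group.add_argument('--transformer-impl', default='local',
                   choices=['local', 'transformer_engine'], dest='transformer_impl')

args = parser.parse_args(['--fp8-hybrid', '--transformer-impl', 'transformer_engine'])

# Mirrors the new validate_args checks: FP8 requires the Transformer Engine
# implementation, and the two FP8 formats are mutually exclusive.
if args.fp8_e4m3 or args.fp8_hybrid:
    assert args.transformer_impl == 'transformer_engine'
assert not (args.fp8_e4m3 and args.fp8_hybrid)
print(args.fp8_hybrid, args.transformer_impl)   # True transformer_engine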
@@ -18,11 +18,14 @@ def load(args):
     # Check if cuda 11 is installed for compute capability 8.0
     cc_flag = []
-    _, bare_metal_major, _ = _get_cuda_bare_metal_version(
+    _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
         cpp_extension.CUDA_HOME)
     if int(bare_metal_major) >= 11:
         cc_flag.append('-gencode')
         cc_flag.append('arch=compute_80,code=sm_80')
+        if int(bare_metal_minor) >= 7:
+            cc_flag.append('-gencode')
+            cc_flag.append('arch=compute_90,code=sm_90')
 
     # Build path
     srcpath = pathlib.Path(__file__).parent.absolute()
@@ -75,11 +78,14 @@ def load(args):
     # Mixed precision fused layer norm.
     # =================================
 
+    extra_hopper_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                          '-U__CUDA_NO_HALF_CONVERSIONS__']
+
     extra_cuda_flags = ['-maxrregcount=50']
     sources=[srcpath / 'layer_norm_cuda.cpp',
              srcpath / 'layer_norm_cuda_kernel.cu']
     fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
-        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags + extra_hopper_flags)
 
     # =================================
     # Fused gradient accumulation to weight gradient computation of linear layer
@@ -89,7 +95,7 @@ def load(args):
     sources=[srcpath / 'fused_weight_gradient_dense.cpp',
              srcpath / 'fused_weight_gradient_dense.cu']
     fused_dense_cuda = _cpp_extention_load_helper(
-        "fused_dense_cuda", sources, [])
+        "fused_dense_cuda", sources, extra_hopper_flags)
 
 
 def _get_cuda_bare_metal_version(cuda_dir):
......
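The fused-kernel build now also gates Hopper kernels on the toolkit version: sm_80 is requested for any CUDA 11+ install, and sm_90 is added when the minor version is at least 7. A self-contained sketch of that selection logic follows; select_cc_flags is a hypothetical helper, and in the real code the version numbers come from _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME).

# Hypothetical helper mirroring the gencode selection above.
def select_cc_flags(bare_metal_major, bare_metal_minor):
    cc_flag = []
    if int(bare_metal_major) >= 11:
        # Ampere (sm_80) kernels for any CUDA 11+ toolkit.
        cc_flag += ['-gencode', 'arch=compute_80,code=sm_80']
        if int(bare_metal_minor) >= 7:
            # Hopper (sm_90) kernels only on newer 11.x toolkits.
            cc_flag += ['-gencode', 'arch=compute_90,code=sm_90']
    return cc_flag

print(select_cc_flags('11', '8'))   # sm_80 and sm_90
print(select_cc_flags('11', '3'))   # sm_80 only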
@@ -6,7 +6,7 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F
 
-from megatron import get_timers, get_args, core
+from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
@@ -25,7 +25,6 @@ try:
 except ImportError:
     flash_attn_unpadded_func = None
-
 
 """ We use the following notation throughout this file:
     h: hidden size
    n: number of attention heads
@@ -890,6 +889,7 @@ class ParallelTransformer(MegatronModule):
         self.post_process = post_process
         self.input_tensor = None
         self.drop_path_rate = drop_path_rate
+        self.transformer_impl = args.transformer_impl
 
         # Store activation checkpointing flag.
         self.recompute_granularity = args.recompute_granularity
@@ -900,6 +900,31 @@ class ParallelTransformer(MegatronModule):
         self.sequence_parallel = args.sequence_parallel
 
+        # Transformer Engine Init.
+        if self.transformer_impl == 'transformer_engine':
+            global transformer_engine
+            import transformer_engine
+        self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
+        self.fp8_recipe = None
+        self.fp8_group = mpu.get_data_parallel_group()
+        if self.use_fp8:
+            if args.fp8_e4m3:
+                fp8_format = transformer_engine.common.recipe.Format.E4M3
+            elif args.fp8_hybrid:
+                fp8_format = transformer_engine.common.recipe.Format.HYBRID
+            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
+                margin=args.fp8_margin,
+                interval=args.fp8_interval,
+                fp8_format=fp8_format,
+                amax_history_len=args.fp8_amax_history_len,
+                amax_compute_algo=args.fp8_amax_compute_algo,
+                override_linear_precision=(False, False, not args.fp8_wgrad),
+            )
+
+        self.num_microbatches_in_previous_step = -1
+        self.microbatch_count = 0
+        self.checkpoint_core_attention = args.recompute_granularity == 'selective'
+
         # Number of layers.
         self.num_layers = _get_num_layers(
             args,
@@ -910,6 +935,7 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         def build_layer(layer_number):
+            if args.transformer_impl == 'local':
                 return ParallelTransformerLayer(
                     init_method,
                     output_layer_init_method,
@@ -917,6 +943,35 @@ class ParallelTransformer(MegatronModule):
                     layer_type=layer_type,
                     self_attn_mask_type=self_attn_mask_type,
                     drop_path_rate=self.drop_path_rates[layer_number - 1])
+            else:
+                return transformer_engine.pytorch.TransformerLayer(
+                    args.hidden_size,
+                    args.ffn_hidden_size,
+                    args.num_attention_heads,
+                    layernorm_epsilon=args.layernorm_epsilon,
+                    hidden_dropout=args.hidden_dropout,
+                    attention_dropout=args.attention_dropout,
+                    init_method=init_method,
+                    output_layer_init_method=output_layer_init_method,
+                    layer_number=layer_number,
+                    kv_channels=args.kv_channels,
+                    self_attn_mask_type=self_attn_mask_type.name,
+                    tp_group=mpu.get_tensor_model_parallel_group(),
+                    get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
+                    fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
+                    apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
+                    attention_softmax_in_fp32=args.attention_softmax_in_fp32,
+                    seq_length=args.seq_length,
+                    micro_batch_size=args.micro_batch_size,
+                    sequence_parallel=args.sequence_parallel,
+                    params_dtype=args.params_dtype,
+                    apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
+                    output_layernorm=False,
+                    layer_type="encoder",
+                    drop_path_rate=self.drop_path_rates[layer_number - 1],
+                    set_parallel_mode=True,
+                    fuse_qkv_params=True)
 
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
@@ -976,19 +1031,20 @@ class ParallelTransformer(MegatronModule):
         return self.layers[layer_number]
 
     def _checkpointed_forward(self, hidden_states, attention_mask,
-                              encoder_output, enc_dec_attn_mask):
+                              encoder_output, enc_dec_attn_mask, is_first_microbatch):
         """Forward method with activation checkpointing."""
-        def custom(start, end):
-            def custom_forward(*inputs):
-                x_ = inputs[0]
-                attention_mask = inputs[1]
-                encoder_output = inputs[2]
-                enc_dec_attn_mask = inputs[3]
+        def custom(start, end, is_transformer_engine=False):
+            def custom_forward(*args, **kwargs):
+                x_, *args = args
                 for index in range(start, end):
                     layer = self._get_layer(index)
-                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
+                    x_ = layer(x_, *args, **kwargs)
                 return x_
-            return custom_forward
+
+            def custom_forward_transformer_engine(*args, **kwargs):
+                return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
+
+            if not is_transformer_engine:
                 return custom_forward
+            else:
+                return custom_forward_transformer_engine
 
         if self.recompute_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
@@ -996,10 +1052,19 @@
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
+                if self.transformer_impl == 'transformer_engine':
+                    hidden_states = transformer_engine.pytorch.distributed.checkpoint(
+                        custom(l, l + self.recompute_num_layers, is_transformer_engine=True),
+                        self.distribute_saved_activations,
+                        tensor_parallel.get_cuda_rng_tracker,
+                        mpu.get_tensor_model_parallel_group(),
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                else:
                     hidden_states = tensor_parallel.checkpoint(
                         custom(l, l + self.recompute_num_layers),
                         self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.recompute_num_layers
 
         elif self.recompute_method == 'block':
@@ -1008,10 +1073,22 @@ class ParallelTransformer(MegatronModule):
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
                 if l < self.recompute_num_layers:
+                    if self.transformer_impl == 'transformer_engine':
+                        hidden_states = transformer_engine.pytorch.distributed.checkpoint(
+                            custom(l, l + 1, is_transformer_engine=True),
+                            self.distribute_saved_activations,
+                            tensor_parallel.get_cuda_rng_tracker,
+                            mpu.get_tensor_model_parallel_group(),
+                            hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                    else:
                         hidden_states = tensor_parallel.checkpoint(
                             custom(l, l + 1),
                             self.distribute_saved_activations,
                             hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
+                    if self.transformer_impl == 'transformer_engine':
+                        hidden_states = custom(l, l + 1, is_transformer_engine=True)(
+                            hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                    else:
                         hidden_states = custom(l, l + 1)(
                             hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -1071,21 +1148,48 @@ class ParallelTransformer(MegatronModule):
             rng_context = nullcontext()
 
         with rng_context:
+            # The fp8_autocast context manager is a no-op when enabled=False.
+            # The if...else serves to short circuit name resolution for fp8_autocast.
+            with transformer_engine.pytorch.fp8_autocast(
+                enabled=self.use_fp8,
+                fp8_recipe=self.fp8_recipe,
+                fp8_group=self.fp8_group
+            ) if self.use_fp8 else nullcontext():
+                # Determine if the current iteration is the first microbatch.
+                if self.num_microbatches_in_previous_step != get_num_microbatches():
+                    self.microbatch_count = 0  # Reset count on new batch size rampup interval
+                self.num_microbatches_in_previous_step = get_num_microbatches()
+                is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
+
                 # Forward pass.
                 if self.recompute_granularity == 'full':
                     hidden_states = self._checkpointed_forward(hidden_states,
                                                                attention_mask,
                                                                encoder_output,
-                                                               enc_dec_attn_mask)
+                                                               enc_dec_attn_mask,
+                                                               is_first_microbatch)
                 else:
+                    forward_kwargs = {
+                        'encoder_output': encoder_output,
+                        'enc_dec_attn_mask': enc_dec_attn_mask,
+                        'inference_params': inference_params,
+                    }
+
+                    if self.transformer_impl == 'transformer_engine':
+                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
+                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
+
                     for index in range(self.num_layers):
                         layer = self._get_layer(index)
                         hidden_states = layer(
                             hidden_states,
                             attention_mask,
-                            encoder_output=encoder_output,
-                            enc_dec_attn_mask=enc_dec_attn_mask,
-                            inference_params=inference_params)
+                            **forward_kwargs)
+
+                # Skip counter update for eval and activation checkpointing
+                if torch.is_grad_enabled() and self.training:
+                    self.microbatch_count += 1
 
         # Final layer norm.
         if self.post_process and self.post_layer_norm:
......
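Two pieces of the forward path above are worth isolating. First, fp8_autocast is entered only when FP8 was requested, with the DelayedScaling recipe built in __init__ controlling the amax history and scaling updates. Second, is_first_microbatch is derived from a counter so that Transformer Engine can treat the first microbatch of each global batch specially (for example, refreshing cached weight casts). Below is a minimal, dependency-free sketch of just that counter logic; MicrobatchTracker is a hypothetical name, and the real code keeps the same state on ParallelTransformer and reads get_num_microbatches().

# Hypothetical, dependency-free sketch of the microbatch bookkeeping above.
class MicrobatchTracker:
    def __init__(self):
        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0

    def is_first_microbatch(self, num_microbatches):
        # Reset the counter whenever the global microbatch count changes,
        # e.g. on a batch-size ramp-up boundary.
        if self.num_microbatches_in_previous_step != num_microbatches:
            self.microbatch_count = 0
        self.num_microbatches_in_previous_step = num_microbatches
        return self.microbatch_count % num_microbatches == 0

    def step(self, training=True):
        # Mirrors the guarded update above: skipped for eval / no-grad passes.
        if training:
            self.microbatch_count += 1

tracker = MicrobatchTracker()
flags = []
for _ in range(2):            # two global batches of 4 microbatches each
    for _ in range(4):
        flags.append(tracker.is_first_microbatch(4))
        tracker.step()
print(flags)   # [True, False, False, False, True, False, False, False]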
@@ -26,6 +26,7 @@ from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import Float16Module
 from megatron.model import ModelType
+from megatron.model import GPTModel
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
@@ -251,6 +252,12 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     if not isinstance(model, list):
         model = [model]
 
+    # Disallow training and inference with Transformer Engine
+    # for non-GPT models.
+    args.allow_transformer_engine = all([type(m) == GPTModel for m in model])
+    assert args.allow_transformer_engine or args.transformer_impl == 'local', \
+        'Transformer Engine is only approved for GPT models'
+
     # Set tensor model parallel attributes if not set.
     # Only parameters that are already tensor model parallel have these
     # attributes set for them. We should make sure the default attributes
......
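Note that the guard added to get_model uses an exact type comparison, so only bare GPTModel instances enable Transformer Engine; subclasses and other model types fall back to requiring --transformer-impl local. A small illustration of that distinction (CustomGPT is a hypothetical stand-in):

# Hypothetical illustration of the exact type check used above; a subclass of
# GPTModel is not treated as Transformer-Engine compatible by this check.
class GPTModel: ...
class CustomGPT(GPTModel): ...

print(all([type(m) == GPTModel for m in [GPTModel()]]))                   # True
print(all([type(m) == GPTModel for m in [GPTModel(), CustomGPT()]]))      # False
print(all([isinstance(m, GPTModel) for m in [GPTModel(), CustomGPT()]]))  # True (not what the check uses)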
@@ -113,4 +113,6 @@ if __name__ == "__main__":
     pretrain(train_valid_test_datasets_provider, model_provider,
              ModelType.encoder_or_decoder,
-             forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+             forward_step,
+             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
+             )
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
-from megatron.core import mpu
+from megatron.core import parallel_state, tensor_parallel
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel
 from megatron.training import get_model
@@ -90,10 +90,10 @@ def forward_step(batch, model, eval_metric):
     send_forward(output)
 
-    if mpu.is_pipeline_last_stage():
+    if parallel_state.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
         if eval_metric == 'loss':
-            losses = mpu.tensor_parallel.vocab_parallel_cross_entropy(
+            losses = tensor_parallel.vocab_parallel_cross_entropy(
                 output.contiguous().float(), labels.contiguous())
             loss = torch.sum(
                 losses.view(-1) * loss_mask.contiguous().view(-1).float())
@@ -129,9 +129,9 @@ def evaluate(data_loader, model, eval_metric):
         output = forward_step(batch, model, eval_metric)
 
         # Reduce across processes.
-        if mpu.is_pipeline_last_stage():
+        if parallel_state.is_pipeline_last_stage():
             torch.distributed.all_reduce(output,
-                                         group=mpu.get_data_parallel_group())
+                                         group=parallel_state.get_data_parallel_group())
 
             total_output += output
......