Commit 7f9a48ba authored by Vijay Korthikanti

Merge branch 'main' into sequence_parallel

parents 00129014 d2394294
@@ -31,6 +31,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu
def initialize_megatron(extra_args_provider=None, args_defaults={},
@@ -64,9 +66,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
print('> setting random seeds to {} ...'.format(args.seed))
_set_random_seed(args.seed, args.data_parallel_random_init)
# Set pytorch JIT layer fusion options.
_set_jit_fusion_options()
args = get_args()
if args.lazy_mpu_init:
args.use_cpu_initialization=True
@@ -230,7 +229,7 @@ def write_args_to_tensorboard():
global_step=args.iteration)
def _set_jit_fusion_options():
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
@@ -251,3 +250,47 @@ def _set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
_warmup_jit_function()
def _warmup_jit_function():
""" Compilie JIT functions before the main training steps """
args = get_args()
if args.bf16:
dtype = torch.bfloat16
elif args.fp16:
dtype = torch.float16
else:
dtype = torch.float32
# Warmup fused bias+gelu
bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
dtype=dtype, device='cuda')
input = torch.rand((args.seq_length, args.micro_batch_size,
args.ffn_hidden_size // args.tensor_model_parallel_size),
dtype=dtype, device='cuda')
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]):
bias.requires_grad, input.requires_grad = bias_grad, input_grad
for _ in range(5):
output = bias_gelu(bias, input)
del bias, input, output
# Warmup fused bias+dropout+add
input = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
residual = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
dropout_rate = 0.1
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
input.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad
for _ in range(5):
output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
del bias, input, residual, output
torch.cuda.empty_cache()
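For context, the scripted bias+gelu being warmed up here (imported above from megatron.model.fused_bias_gelu) is, roughly, a single TorchScript function over the tanh approximation of GeLU; the sketch below is illustrative rather than the exact kernel. Running it under both requires_grad combinations matters because the JIT profiling executor specializes and fuses separately per differentiability configuration, so both the forward-prop and activation-recomputation variants are compiled before the first real training step.

import torch

# Illustrative sketch of a TorchScript-fused bias+gelu (tanh approximation);
# the pointwise ops below can be fused into a single kernel by the JIT.
@torch.jit.script
def bias_gelu_sketch(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))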
@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
from torch.nn import init
import importlib
from megatron.mpu import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True
@@ -113,6 +115,15 @@ class MixedFusedLayerNorm(torch.nn.Module):
return FusedLayerNormAffineFunction.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)
else:
return FastLayerNormFN.apply(
output = FastLayerNormFN.apply(
input, self.weight, self.bias, self.eps)
# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
# a populated '_base' field). This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
output = make_viewless_tensor(inp = output,
requires_grad = input.requires_grad,
keep_graph = True)
return output
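For context, make_viewless_tensor (imported above from megatron.mpu) only needs to return a tensor whose '_base' field is unpopulated while still sharing the layer-norm output's storage. A minimal stand-alone sketch of that idea follows; it ignores the keep_graph=True autograd handling the real helper also performs.

import torch

def make_viewless_sketch(inp: torch.Tensor, requires_grad: bool) -> torch.Tensor:
    # Allocate a fresh tensor object and point it at the same storage, so the
    # result has no recorded view relationship (output._base is None) and
    # schedule.py's deallocate_output_tensor() can safely reclaim it.
    out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                      requires_grad=requires_grad)
    out.data = inp.data
    return out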
@@ -15,7 +15,7 @@
"""Transformer."""
import math
import contextlib
from contextlib import nullcontext
import torch
import torch.nn.functional as F
@@ -593,6 +593,13 @@ class ParallelTransformerLayer(MegatronModule):
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
self.bias_dropout_add_exec_handler = \
nullcontext if use_nvfuser else torch.enable_grad
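For context, the bias-dropout-add helpers this handler wraps (and which the hunks below call) look roughly like the sketch that follows; the real definitions live elsewhere in megatron/model/transformer.py, and the fused variant is TorchScript-compiled so the bias add, dropout, and residual add can be fused. With nvFuser (PyTorch >= 1.10) that fusion no longer requires grad mode to be forced on, which is why nullcontext replaces the old torch.enable_grad wrapper and torch.enable_grad remains only as the fallback for older fusers.

import torch

# Simplified sketch of the eager and fused bias-dropout-add paths.
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor,
                     prob: float, training: bool) -> torch.Tensor:
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out

def get_bias_dropout_add(training: bool):
    # Returns a closure with `training` bound, matching the call sites below.
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add

@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor,
                                 residual: torch.Tensor, prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)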
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
@@ -626,8 +633,7 @@
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
@@ -653,8 +659,7 @@
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
@@ -674,8 +679,7 @@
residual = layernorm_input
if self.drop_path is None:
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
@@ -909,7 +913,7 @@
if self.sequence_parallel:
rng_context = mpu.get_cuda_rng_tracker().fork()
else:
rng_context = contextlib.nullcontext
rng_context = nullcontext()
with rng_context:
# Forward pass.
......
@@ -241,7 +241,6 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
import fused_dense_cuda
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
@@ -296,6 +295,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
if ctx.gradient_accumulation_fusion:
import fused_dense_cuda
fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
grad_weight = None
else:
......
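For comparison with the fused call above, a rough sketch of both weight-gradient paths follows. The unfused branch shown here (a plain GEMM into a fresh grad_weight) is an assumption about the standard path that the diff truncates; the fused branch accumulates straight into the fp32 weight.main_grad buffer, which is why grad_weight is set to None.

import torch

def wgrad_sketch(total_input: torch.Tensor, grad_output: torch.Tensor,
                 weight: torch.nn.Parameter, gradient_accumulation_fusion: bool):
    # Flatten [seq, batch, hidden] activations/grads to 2-D for the GEMM.
    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    total_input_2d = total_input.reshape(-1, total_input.shape[-1])
    if gradient_accumulation_fusion:
        # Fused path: one kernel performs the GEMM and accumulates the result
        # directly into the fp32 master gradient, so no grad_weight tensor is
        # materialized (requires the fused_dense_cuda extension to be built).
        import fused_dense_cuda
        fused_dense_cuda.wgrad_gemm_accum_fp32(
            total_input_2d, grad_output_2d, weight.main_grad)
        return None
    # Unfused path (assumed): materialize the weight gradient explicitly.
    return grad_output_2d.t().matmul(total_input_2d)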
@@ -43,6 +43,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
@@ -100,6 +101,8 @@ def pretrain(train_valid_test_dataset_provider,
# Initialize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
......
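Since set_jit_fusion_options() is no longer invoked inside initialize_megatron() (see the first hunk), entry points other than pretrain() that still want the fusion flags and the JIT warmup now have to call it themselves. An illustrative skeleton, where main() and its argument handling are placeholders rather than part of this change:

from megatron.initialize import initialize_megatron, set_jit_fusion_options

def main():
    # Argument parsing, distributed setup, random seeds, etc.
    initialize_megatron(extra_args_provider=None, args_defaults={})
    # Set PyTorch JIT layer fusion options and warm up the fused functions,
    # mirroring what pretrain() now does right after initialization.
    set_jit_fusion_options()
    # ... build the model and optimizer and run training as before ...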