Commit f602ac56 authored by Sangkug Lym's avatar Sangkug Lym
Browse files

(1)support pyt20.12 compatibility, (2) arg naming update

change dummy_handler to nullcontext
parent c7fef593
......@@ -66,9 +66,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
print('> setting random seeds to {} ...'.format(args.seed))
_set_random_seed(args.seed, args.data_parallel_random_init)
# Set pytorch JIT layer fusion options.
_set_jit_fusion_options()
args = get_args()
if args.lazy_mpu_init:
args.use_cpu_initialization=True
......@@ -232,7 +229,7 @@ def write_args_to_tensorboard():
global_step=args.iteration)
def _set_jit_fusion_options():
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
......@@ -253,41 +250,47 @@ def _set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
_warmup_jit_function()
def warmup_jit_function():
def _warmup_jit_function():
""" Compilie JIT functions before the main training steps """
args = get_args()
if args.bf16:
p = torch.bfloat16
dtype = torch.bfloat16
elif args.fp16:
p = torch.float16
dtype = torch.float16
else:
p = torch.float32
dtype = torch.float32
# Warmup fused bias+gelu
b = torch.rand(int(args.hidden_size * 4 / args.tensor_model_parallel_size),
dtype=p, device='cuda')
x = torch.rand((args.seq_length, args.micro_batch_size,
int(args.hidden_size * 4 / args.tensor_model_parallel_size)),
dtype=p, device='cuda')
# Warmup JIT fusions with the input grad_enable state at both forward
bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
dtype=dtype, device='cuda')
input = torch.rand((args.seq_length, args.micro_batch_size,
args.ffn_hidden_size // args.tensor_model_parallel_size),
dtype=dtype, device='cuda')
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for b_grad, x_grad in zip([True, True], [False, True]):
b.requires_grad, x.requires_grad = b_grad, x_grad
for bias_grad, input_grad in zip([True, True], [False, True]):
bias.requires_grad, input.requires_grad = bias_grad, input_grad
for _ in range(5):
y = bias_gelu(b, x)
del b, x, y
output = bias_gelu(bias, input)
del bias, input, output
# Warmup fused bias+dropout+add
input_size = (args.seq_length, args.micro_batch_size, args.hidden_size)
x = torch.rand(input_size, dtype=p, device='cuda')
r = torch.rand(input_size, dtype=p, device='cuda')
b = torch.rand((args.hidden_size), dtype=p, device='cuda').expand_as(r)
# Warmup JIT fusions with the input grad_enable state at both forward
input = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
residual = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
dropout_rate = 0.1
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for x_grad, b_grad, r_grad in zip([False, True], [True, True], [True, True]):
x.requires_grad, b.requires_grad, r.requires_grad = x_grad, b_grad, r_grad
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
input.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad
for _ in range(5):
y = bias_dropout_add_fused_train(x, b, r, 0.1)
del b, x, r, y
output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
del bias, input, residual, output
torch.cuda.empty_cache()
......@@ -15,6 +15,7 @@
"""Transformer."""
import math
from contextlib import nullcontext
import torch
import torch.nn.functional as F
......@@ -531,6 +532,13 @@ class ParallelTransformerLayer(MegatronModule):
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
self.bias_dropout_add_exec_handler = \
nullcontext if use_nvfuser else torch.enable_grad
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
......@@ -564,11 +572,12 @@ class ParallelTransformerLayer(MegatronModule):
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
......@@ -589,11 +598,12 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = layernorm_input
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
......@@ -608,11 +618,12 @@ class ParallelTransformerLayer(MegatronModule):
residual = layernorm_input
if self.drop_path is None:
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
......
......@@ -218,7 +218,6 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
import fused_dense_cuda
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
......@@ -236,6 +235,7 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
import fused_dense_cuda
fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
grad_weight = None
else:
......
......@@ -42,7 +42,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import warmup_jit_function
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
......@@ -100,7 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
warmup_jit_function()
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment