Commit f602ac56 authored by Sangkug Lym

(1) Support PyT 20.12 compatibility, (2) arg naming update

change dummy_handler to nullcontext
parent c7fef593
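
Background on the "dummy_handler to nullcontext" note above: the hand-rolled no-op context manager used previously is not shown in this diff, so the helper below is only an assumed reconstruction, but the substitution it illustrates is what the transformer changes further down rely on: contextlib.nullcontext (Python 3.7+) is a drop-in no-op replacement.

import contextlib

# Assumed shape of the old helper (not part of this diff):
@contextlib.contextmanager
def dummy_handler():
    try:
        yield
    finally:
        pass

# Standard-library equivalent used after this commit; both are no-ops.
with dummy_handler():
    pass
with contextlib.nullcontext():
    pass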
@@ -66,9 +66,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
             print('> setting random seeds to {} ...'.format(args.seed))
         _set_random_seed(args.seed, args.data_parallel_random_init)
 
-    # Set pytorch JIT layer fusion options.
-    _set_jit_fusion_options()
-
     args = get_args()
     if args.lazy_mpu_init:
         args.use_cpu_initialization=True
@@ -232,7 +229,7 @@ def write_args_to_tensorboard():
                          global_step=args.iteration)
 
 
-def _set_jit_fusion_options():
+def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options."""
     # flags required to enable jit fusion kernels
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
@@ -253,41 +250,47 @@ def _set_jit_fusion_options():
         torch._C._jit_override_can_fuse_on_cpu(True)
         torch._C._jit_override_can_fuse_on_gpu(True)
 
+    _warmup_jit_function()
+
 
-def warmup_jit_function():
+def _warmup_jit_function():
     """ Compile JIT functions before the main training steps """
     args = get_args()
     if args.bf16:
-        p = torch.bfloat16
+        dtype = torch.bfloat16
     elif args.fp16:
-        p = torch.float16
+        dtype = torch.float16
     else:
-        p = torch.float32
+        dtype = torch.float32
 
     # Warmup fused bias+gelu
-    b = torch.rand(int(args.hidden_size * 4 / args.tensor_model_parallel_size),
-                   dtype=p, device='cuda')
-    x = torch.rand((args.seq_length, args.micro_batch_size,
-                    int(args.hidden_size * 4 / args.tensor_model_parallel_size)),
-                   dtype=p, device='cuda')
-    # Warmup JIT fusions with the input grad_enable state at both forward
+    bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
+                      dtype=dtype, device='cuda')
+    input = torch.rand((args.seq_length, args.micro_batch_size,
+                        args.ffn_hidden_size // args.tensor_model_parallel_size),
+                       dtype=dtype, device='cuda')
+    # Warmup JIT fusions with the input grad_enable state of both forward
     # prop and recomputation
-    for b_grad, x_grad in zip([True, True], [False, True]):
-        b.requires_grad, x.requires_grad = b_grad, x_grad
+    for bias_grad, input_grad in zip([True, True], [False, True]):
+        bias.requires_grad, input.requires_grad = bias_grad, input_grad
         for _ in range(5):
-            y = bias_gelu(b, x)
-    del b, x, y
+            output = bias_gelu(bias, input)
+    del bias, input, output
 
     # Warmup fused bias+dropout+add
-    input_size = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    x = torch.rand(input_size, dtype=p, device='cuda')
-    r = torch.rand(input_size, dtype=p, device='cuda')
-    b = torch.rand((args.hidden_size), dtype=p, device='cuda').expand_as(r)
-    # Warmup JIT fusions with the input grad_enable state at both forward
+    input = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
+                       dtype=dtype, device='cuda')
+    residual = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
+                          dtype=dtype, device='cuda')
+    bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
+    dropout_rate = 0.1
+    # Warmup JIT fusions with the input grad_enable state of both forward
     # prop and recomputation
-    for x_grad, b_grad, r_grad in zip([False, True], [True, True], [True, True]):
-        x.requires_grad, b.requires_grad, r.requires_grad = x_grad, b_grad, r_grad
+    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
+        input.requires_grad = input_grad
+        bias.requires_grad = bias_grad
+        residual.requires_grad = residual_grad
         for _ in range(5):
-            y = bias_dropout_add_fused_train(x, b, r, 0.1)
-    del b, x, r, y
+            output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
+    del bias, input, residual, output
     torch.cuda.empty_cache()
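
The warmup exists because the TorchScript fuser specializes its generated kernels on dtype, device, and the requires_grad state of the inputs, so the first few real training steps (and the first recomputation pass under activation checkpointing) would otherwise pay compilation stalls. A minimal standalone sketch of the same idea, using a locally scripted bias-GeLU of the same functional form rather than importing Megatron's fused kernel, and with made-up shapes:

import torch

@torch.jit.script
def bias_gelu(bias, y):
    # Tanh-approximate GeLU applied to (input + bias); scripting lets the
    # JIT fuser emit a fused elementwise kernel after warmup.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

def warmup_bias_gelu(seq=2048, batch=2, ffn_hidden=4096, dtype=torch.float16):
    bias = torch.rand(ffn_hidden, dtype=dtype, device='cuda')
    inp = torch.rand((seq, batch, ffn_hidden), dtype=dtype, device='cuda')
    # Cover both grad configurations the fuser will see: forward prop
    # (activations do not require grad) and recomputation (they do).
    for bias_grad, inp_grad in zip([True, True], [False, True]):
        bias.requires_grad, inp.requires_grad = bias_grad, inp_grad
        for _ in range(5):
            _ = bias_gelu(bias, inp)
    torch.cuda.empty_cache()

if torch.cuda.is_available():
    warmup_bias_gelu()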
@@ -15,6 +15,7 @@
 """Transformer."""
 import math
+from contextlib import nullcontext
 
 import torch
 import torch.nn.functional as F
@@ -531,6 +532,13 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             self.mlp = ParallelMLP(init_method, output_layer_init_method)
 
+        # Set bias+dropout+add fusion grad_enable execution handler.
+        TORCH_MAJOR = int(torch.__version__.split('.')[0])
+        TORCH_MINOR = int(torch.__version__.split('.')[1])
+        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        self.bias_dropout_add_exec_handler = \
+            nullcontext if use_nvfuser else torch.enable_grad
+
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
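
Note that the handler is stored as a class, not an instance: contextlib.nullcontext and torch.enable_grad are both zero-argument callables that return context managers, so every call site can simply write with self.bias_dropout_add_exec_handler(): regardless of the PyTorch version. A standalone sketch of the selection and use; the comment reflects my reading of why enable_grad is kept for the pre-nvfuser path:

from contextlib import nullcontext
import torch

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)

# Older fuser versions were executed with grad mode forced on (the role the
# previous handler played); nvfuser does not need that, so a no-op suffices.
exec_handler = nullcontext if use_nvfuser else torch.enable_grad

with exec_handler():
    pass  # fused bias+dropout+add would run here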
@@ -564,11 +572,12 @@ class ParallelTransformerLayer(MegatronModule):
             else:
                 bias_dropout_add_func = get_bias_dropout_add(self.training)
 
-            layernorm_input = bias_dropout_add_func(
-                attention_output,
-                attention_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+            with self.bias_dropout_add_exec_handler():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
         else:
             out = torch.nn.functional.dropout(attention_output + attention_bias,
                                               p=self.hidden_dropout,
@@ -589,11 +598,12 @@ class ParallelTransformerLayer(MegatronModule):
             else:
                 residual = layernorm_input
 
-            layernorm_input = bias_dropout_add_func(
-                attention_output,
-                attention_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+            with self.bias_dropout_add_exec_handler():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
 
             # Layer norm post the decoder attention
             layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
@@ -608,11 +618,12 @@ class ParallelTransformerLayer(MegatronModule):
             residual = layernorm_input
 
         if self.drop_path is None:
-            output = bias_dropout_add_func(
-                mlp_output,
-                mlp_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+            with self.bias_dropout_add_exec_handler():
+                output = bias_dropout_add_func(
+                    mlp_output,
+                    mlp_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
         else:
             out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                               p=self.hidden_dropout,
@@ -218,7 +218,6 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
-        import fused_dense_cuda
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
         grad_input = grad_output.matmul(weight)
@@ -236,6 +235,7 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
         # all-reduce scheduled first and have GPU resources allocated
         _ = torch.empty(1, device=grad_output.device) + 1
         if ctx.gradient_accumulation_fusion:
+            import fused_dense_cuda
             fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
             grad_weight = None
         else:
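
Across these two hunks the fused_dense_cuda import moves from the top of backward() into the gradient_accumulation_fusion branch, so the extension only has to be importable when that fusion is actually enabled; environments lacking the extension (such as the PyT 20.12 container this commit targets, as I read the commit message) can still run the unfused path. A generic sketch of the deferred-import pattern with simplified 2-D shapes; the fallback matmul is illustrative, not the file's exact unfused code:

import torch

def wgrad(total_input, grad_output, main_grad, use_fusion):
    # total_input: [tokens, in_features], grad_output: [tokens, out_features]
    if use_fusion:
        # Deferred import: only fails if the fused path is requested on an
        # installation without the optional CUDA extension.
        import fused_dense_cuda
        fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, main_grad)
        return None
    # Unfused fallback computed with a plain GEMM.
    return grad_output.t().matmul(total_input)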
@@ -42,7 +42,7 @@ from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
-from megatron.initialize import warmup_jit_function
+from megatron.initialize import set_jit_fusion_options
 from megatron.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.utils import check_adlr_autoresume_termination
@@ -100,7 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Initialize and get arguments, timers, and Tensorboard writer.
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
-    warmup_jit_function()
+    # Set pytorch JIT layer fusion options and warmup JIT functions.
+    set_jit_fusion_options()
 
     # Adjust the startup time so it reflects the largest value.
     # This will be closer to what scheduler will see (outside of
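
With the warmup folded into set_jit_fusion_options, the training entry point now makes one explicit call right after Megatron initialization instead of initialize_megatron doing it implicitly. A condensed sketch of the resulting startup order; pretrain_startup is a hypothetical wrapper, and the real pretrain continues with timers, model setup, and the training loop:

from megatron.initialize import initialize_megatron
from megatron.initialize import set_jit_fusion_options

def pretrain_startup(extra_args_provider=None, args_defaults={}):
    # 1) Parse arguments, set seeds, initialize distributed / model parallelism.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)
    # 2) Configure JIT fusion flags and warm up the fused bias+gelu and
    #    bias+dropout+add kernels before the first real training step.
    set_jit_fusion_options()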