Commit c7fef593 authored by Sangkug Lym

JIT function warmups to (1) match fprop and recompute results and (2) remove grad_enable

parent 15f6bb1b
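A note on the mechanism (not part of the commit itself): with activation recomputation, a transformer layer's forward pass runs once with autograd disabled and is replayed with autograd enabled during backward, so the scripted fused functions see two different requires_grad signatures on their inputs. The TorchScript profiling executor compiles and profiles per signature, which is why the warmup added below replays each function under both grad states instead of forcing a single state with torch.enable_grad(). A minimal sketch of the two-pass behavior, using torch.utils.checkpoint as a stand-in for Megatron's own activation checkpointing and with illustrative names (`block`, `linear`):

```python
# Minimal sketch (not from this repo): with activation checkpointing the
# forward runs once under no_grad and is replayed with grad enabled during
# backward, so downstream fused JIT ops see both requires_grad signatures.
import torch
from torch.utils.checkpoint import checkpoint

linear = torch.nn.Linear(16, 16)

def block(x):
    h = linear(x)
    # First (no-grad) pass: h.requires_grad is False.
    # Backward-time recompute: h.requires_grad is True.
    print('h.requires_grad =', h.requires_grad)
    return torch.nn.functional.gelu(h)

x = torch.rand(4, 16, requires_grad=True)
y = checkpoint(block, x)   # forward now, replayed again during backward()
y.sum().backward()
```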
@@ -31,6 +31,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu
def initialize_megatron(extra_args_provider=None, args_defaults={},
@@ -251,3 +253,41 @@ def _set_jit_fusion_options():
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
def warmup_jit_function():
    """ Compile JIT functions before the main training steps """
    args = get_args()
    if args.bf16:
        p = torch.bfloat16
    elif args.fp16:
        p = torch.float16
    else:
        p = torch.float32

    # Warmup fused bias+gelu
    b = torch.rand(int(args.hidden_size * 4 / args.tensor_model_parallel_size),
                   dtype=p, device='cuda')
    x = torch.rand((args.seq_length, args.micro_batch_size,
                    int(args.hidden_size * 4 / args.tensor_model_parallel_size)),
                   dtype=p, device='cuda')
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for b_grad, x_grad in zip([True, True], [False, True]):
        b.requires_grad, x.requires_grad = b_grad, x_grad
        for _ in range(5):
            y = bias_gelu(b, x)
    del b, x, y

    # Warmup fused bias+dropout+add
    input_size = (args.seq_length, args.micro_batch_size, args.hidden_size)
    x = torch.rand(input_size, dtype=p, device='cuda')
    r = torch.rand(input_size, dtype=p, device='cuda')
    b = torch.rand((args.hidden_size), dtype=p, device='cuda').expand_as(r)
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for x_grad, b_grad, r_grad in zip([False, True], [True, True], [True, True]):
        x.requires_grad, b.requires_grad, r.requires_grad = x_grad, b_grad, r_grad
        for _ in range(5):
            y = bias_dropout_add_fused_train(x, b, r, 0.1)
    del b, x, r, y
    torch.cuda.empty_cache()
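For context, the two functions exercised above, bias_gelu and bias_dropout_add_fused_train, are TorchScript-scripted helpers imported at the top of this file. Paraphrased from megatron/model (not verbatim, and not part of this diff), they look roughly like the sketch below; because they are scripted, each dtype/requires_grad signature triggers its own compilation, which is what the warmup loops front-load.

```python
import torch

@torch.jit.script
def bias_gelu(bias, y):
    # Fused bias add + tanh-approximated GeLU (paraphrased).
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor,
                                 residual: torch.Tensor, prob: float) -> torch.Tensor:
    # Fused bias add + dropout + residual add in training mode (paraphrased).
    out = torch.nn.functional.dropout(x + bias, p=prob, training=True)
    return residual + out
```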
@@ -564,13 +564,11 @@ class ParallelTransformerLayer(MegatronModule):
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
@@ -591,13 +589,11 @@ class ParallelTransformerLayer(MegatronModule):
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
@@ -612,13 +608,11 @@ class ParallelTransformerLayer(MegatronModule):
            residual = layernorm_input

        if self.drop_path is None:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
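On the transformer.py hunks above: per the original in-code comment, the torch.enable_grad() wrapper forced the fused bias-dropout-add call to always run with autograd enabled so the JIT produced its fused kernel; with both grad states warmed up at startup, the call sites can now run under whatever grad mode the (possibly recomputed) forward is in, keeping fprop and recompute results consistent. The same call sites also fall back to a plain closure when bias-dropout fusion is disabled; paraphrased (not part of this diff), get_bias_dropout_add is roughly:

```python
import torch

def bias_dropout_add(x, bias, residual, prob, training):
    # Unfused reference: bias add + dropout + residual add.
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out

def get_bias_dropout_add(training):
    # Bind the training flag so the closure matches the signature of the
    # fused, scripted variants used when bias-dropout fusion is enabled.
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add
```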
@@ -42,6 +42,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import warmup_jit_function
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
@@ -99,6 +100,7 @@ def pretrain(train_valid_test_dataset_provider,
    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)
    warmup_jit_function()

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of