Commit 214a9a08 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

address review comment

parent 28b6ffce
......@@ -266,13 +266,7 @@ def _warmup_jit_function():
# Warmup fused bias+gelu
bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
dtype=dtype, device='cuda')
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
input = torch.rand((seq_length, args.micro_batch_size,
input = torch.rand((args.seq_length, args.micro_batch_size,
args.ffn_hidden_size // args.tensor_model_parallel_size),
dtype=dtype, device='cuda')
# Warmup JIT fusions with the input grad_enable state of both forward
......@@ -284,6 +278,10 @@ def _warmup_jit_function():
del bias, input, output
# Warmup fused bias+dropout+add
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment