Unverified commit 8d4761ad, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Deterministic JIT warmup (#216)



* deterministic JIT warmup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ec0d40d6
......@@ -157,8 +157,11 @@ def bias_dropout_add_fused_inference(
def warmup_jit_bias_dropout_add(
hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
"""Compilie BDA JIT function before the main training steps"""
# Warmup fused bias+dropout+add
"""Compile BDA JIT function before the main training steps"""
# Save cuda RNG state to ensure warmup does not affect reproducibility.
rng_state = torch.cuda.get_rng_state()
inp = torch.rand(
(seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda"
)
......@@ -178,7 +181,9 @@ def warmup_jit_bias_dropout_add(
for _ in range(5):
output = bias_dropout_add_fused_train(inp, bias, residual, dropout_rate)
del bias, inp, residual, output
torch.cuda.empty_cache()
torch.cuda.set_rng_state(rng_state)
def warmup_jit_bias_dropout_add_all_dtypes(
......@@ -195,8 +200,11 @@ def warmup_jit_bias_gelu(
seq_length: int,
micro_batch_size: int,
) -> None:
"""Compilie bias-gelu JIT function before the main training steps"""
# Warmup fused bias+gelu
"""Compile bias-gelu JIT function before the main training steps"""
# Save cuda RNG state to ensure warmup does not affect reproducibility.
rng_state = torch.cuda.get_rng_state()
bias = torch.rand(ffn_hidden_size_per_partition, dtype=dtype, device="cuda")
inp = torch.rand(
(seq_length, micro_batch_size, ffn_hidden_size_per_partition),
......@@ -211,6 +219,9 @@ def warmup_jit_bias_gelu(
output = bias_gelu_fused(inp, bias)
del bias, inp, output
torch.cuda.empty_cache()
torch.cuda.set_rng_state(rng_state)
def warmup_jit_bias_gelu_all_dtypes(
ffn_hidden_size: int, seq_length: int, micro_batch_size: int
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment