Unverified commit 8d4761ad authored by Kirthi Shankar Sivamani, committed by GitHub

Deterministic JIT warmup (#216)



* deterministic JIT warmup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ec0d40d6
@@ -157,8 +157,11 @@ def bias_dropout_add_fused_inference(
 def warmup_jit_bias_dropout_add(
     hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
 ) -> None:
-    """Compilie BDA JIT function before the main training steps"""
+    """Compile BDA JIT function before the main training steps"""
+    # Warmup fused bias+dropout+add
+    # Save cuda RNG state to ensure warmup does not affect reproducibility.
+    rng_state = torch.cuda.get_rng_state()
     inp = torch.rand(
         (seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda"
     )
@@ -178,7 +181,9 @@ def warmup_jit_bias_dropout_add(
     for _ in range(5):
         output = bias_dropout_add_fused_train(inp, bias, residual, dropout_rate)
     del bias, inp, residual, output
     torch.cuda.empty_cache()
+    torch.cuda.set_rng_state(rng_state)

 def warmup_jit_bias_dropout_add_all_dtypes(
@@ -195,8 +200,11 @@ def warmup_jit_bias_gelu(
     seq_length: int,
     micro_batch_size: int,
 ) -> None:
-    """Compilie bias-gelu JIT function before the main training steps"""
+    """Compile bias-gelu JIT function before the main training steps"""
+    # Warmup fused bias+gelu
+    # Save cuda RNG state to ensure warmup does not affect reproducibility.
+    rng_state = torch.cuda.get_rng_state()
     bias = torch.rand(ffn_hidden_size_per_partition, dtype=dtype, device="cuda")
     inp = torch.rand(
         (seq_length, micro_batch_size, ffn_hidden_size_per_partition),
@@ -211,6 +219,9 @@ def warmup_jit_bias_gelu(
         output = bias_gelu_fused(inp, bias)
     del bias, inp, output
+    torch.cuda.empty_cache()
+    torch.cuda.set_rng_state(rng_state)

 def warmup_jit_bias_gelu_all_dtypes(
     ffn_hidden_size: int, seq_length: int, micro_batch_size: int
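
The fix is the same in both warmup helpers: snapshot the CUDA RNG state before the warmup's torch.rand and fused-dropout calls consume random numbers, then restore it once warmup is done, so the first real training step sees exactly the RNG stream it would have seen without any warmup. A minimal standalone sketch of the pattern follows; the noisy_warmup helper is hypothetical and merely stands in for the fused warmup loops in this diff:

    import torch

    def noisy_warmup() -> None:
        # Hypothetical stand-in for a JIT warmup: it advances the CUDA RNG
        # via torch.rand and dropout, just like the fused warmup loops above.
        x = torch.rand((128, 128), device="cuda")
        for _ in range(5):
            x = torch.nn.functional.dropout(x, p=0.1, training=True)

    torch.cuda.manual_seed(1234)

    # Save cuda RNG state so warmup does not affect reproducibility.
    rng_state = torch.cuda.get_rng_state()
    noisy_warmup()
    torch.cuda.set_rng_state(rng_state)
    after_warmup = torch.rand(4, device="cuda")

    # With the same seed and no warmup, the same values are drawn.
    torch.cuda.manual_seed(1234)
    baseline = torch.rand(4, device="cuda")
    assert torch.equal(after_warmup, baseline)

Note that torch.cuda.get_rng_state() captures only the current device; a multi-GPU job that warms up on every rank would use torch.cuda.get_rng_state_all() and torch.cuda.set_rng_state_all() in the same way.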