Unverified Commit ded8b9bd authored by Kirthi Shankar Sivamani, committed by GitHub

Fix numerics for activation recompute (#327)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2a81e939
@@ -220,7 +220,7 @@ def warmup_jit_bias_gelu(
     bias = torch.rand(ffn_hidden_size_per_partition, dtype=dtype, device="cuda")
     inp = torch.rand(
-        (seq_length, micro_batch_size, ffn_hidden_size_per_partition),
+        (seq_length * micro_batch_size, ffn_hidden_size_per_partition),
         dtype=dtype,
         device="cuda",
     )
@@ -229,8 +229,9 @@ def warmup_jit_bias_gelu(
     for bias_grad, input_grad in zip([True, True], [False, True]):
         bias.requires_grad, inp.requires_grad = bias_grad, input_grad
         for _ in range(5):
-            output = bias_gelu_fused(inp, bias)
-    del bias, inp, output
+            _ = bias_gelu_fused_(inp, bias)
+            _ = gelu_fused_(inp)
+    del bias, inp
     torch.cuda.empty_cache()
     torch.cuda.set_rng_state(rng_state)
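The diff makes two changes to the JIT warmup: the dummy input is flattened from a 3D (seq_length, micro_batch_size, hidden) tensor to the 2D (seq_length * micro_batch_size, hidden) shape the fused kernels see at runtime, and the loop now also warms up the bias-less gelu_fused_ path, which is what activation recompute exercises. Below is a minimal, hypothetical sketch of what the warmup ends up exercising; the kernel bodies are assumptions modelled on the common Megatron-style tanh-approximation GeLU fusion, and only the names bias_gelu_fused_ / gelu_fused_ and the flattened input shape come from the diff above.

```python
# Sketch only: the fused-kernel bodies below are assumptions, not the
# repository's implementation.
import torch


@torch.jit.script
def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Fused bias add + tanh-approximation GeLU.
    x = inp + bias
    return x * 0.5 * (1.0 + torch.tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x)))


@torch.jit.script
def gelu_fused_(x: torch.Tensor) -> torch.Tensor:
    # Bias-less variant; warming it up too keeps the recompute pass on an
    # already-compiled kernel with identical numerics.
    return x * 0.5 * (1.0 + torch.tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x)))


if torch.cuda.is_available():
    # Warmup input is flattened to 2D, matching the runtime call shape.
    seq_length, micro_batch_size, ffn_hidden = 128, 2, 4096  # example sizes
    inp = torch.rand(
        (seq_length * micro_batch_size, ffn_hidden),
        dtype=torch.float16,
        device="cuda",
    )
    bias = torch.rand(ffn_hidden, dtype=torch.float16, device="cuda")
    for _ in range(5):
        _ = bias_gelu_fused_(inp, bias)
        _ = gelu_fused_(inp)
```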