Unverified Commit 8eec2004, authored by Peter St. John, committed by GitHub

Disable torch autocast context in rope forward pass (#2240)


Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7ad130ef
@@ -373,3 +373,19 @@ def test_fused_qkv_rope(
     if not isinstance(start_positions, torch.Tensor):
         torch.testing.assert_close(grad_fused, grad_unfused)
+
+
+def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast():
+    rope_layer = RotaryPositionEmbedding(128)
+
+    rope_embeddings_no_autocast = rope_layer(max_seq_len=1024)
+
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        rope_embeddings_autocast = rope_layer(max_seq_len=1024)
+
+    torch.testing.assert_close(
+        rope_embeddings_no_autocast.to(dtype=torch.bfloat16),
+        rope_embeddings_autocast.to(dtype=torch.bfloat16),
+        atol=1e-8,
+        rtol=1e-8,
+    )
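For background, here is a minimal standalone sketch of the failure mode this test guards against. It is not part of the commit: the dim and base values are illustrative, it assumes a CUDA device, and it relies on the fact that torch.autocast executes torch.einsum (the op typically used to expand RoPE frequencies) in the reduced-precision dtype.

    import torch

    # Hypothetical reproduction (not the library code): compute a RoPE-style
    # position/inverse-frequency outer product with and without bf16 autocast.
    dim, max_seq_len = 128, 1024
    inv_freq = 1.0 / (
        10000 ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim)
    )
    seq = torch.arange(max_seq_len, dtype=torch.float32, device="cuda")

    freqs_fp32 = torch.einsum("i,j->ij", seq, inv_freq)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # einsum is on autocast's reduced-precision op list, so this runs in bf16.
        freqs_autocast = torch.einsum("i,j->ij", seq, inv_freq)

    print(freqs_autocast.dtype)  # torch.bfloat16
    print((freqs_fp32 - freqs_autocast.float()).abs().max())  # nonzero rounding error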
@@ -66,6 +66,9 @@ class RotaryPositionEmbedding(torch.nn.Module):
         """
         Create rotary position embedding frequencies.
 
+        This function is particularly sensitive to the use of mixed precision, so we disable the
+        autocast context if it is enabled.
+
         Parameters
         ----------
         max_seq_len: int
@@ -73,6 +76,7 @@
         offset: int, default = 0
             Fixed offset for frequencies.
         """
+        with torch.autocast(enabled=False, device_type="cuda"):
             seq = (
                 torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
                 + offset
......
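With autocast disabled inside the forward pass, the frequencies are computed in the dtype of the inv_freq buffer regardless of any ambient autocast region. A sketch of the expected caller-side behavior, assuming the default fp32 inv_freq buffer (the import path is an assumption and varies across TransformerEngine versions):

    import torch
    # Assumed import path; adjust for your TransformerEngine version.
    from transformer_engine.pytorch.attention import RotaryPositionEmbedding

    rope = RotaryPositionEmbedding(128)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        emb = rope(max_seq_len=1024)

    # The embeddings are computed with autocast disabled, so they stay in the
    # inv_freq dtype (fp32 by default) even under ambient bf16 autocast.
    assert emb.dtype == torch.float32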