Unverified commit 8eec2004, authored by Peter St. John and committed by GitHub
Browse files

Disable torch autocast context in rope forward pass (#2240)


Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7ad130ef
...@@ -373,3 +373,19 @@ def test_fused_qkv_rope( ...@@ -373,3 +373,19 @@ def test_fused_qkv_rope(
if not isinstance(start_positions, torch.Tensor): if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused) torch.testing.assert_close(grad_fused, grad_unfused)
def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast():
    """Check that the RoPE forward pass gives the same result inside and outside autocast.

    The layer is called once with no autocast context and once inside a CUDA
    bfloat16 autocast context; the two outputs must agree after both are cast
    to bfloat16.
    """
    seq_len = 1024
    tolerance = 1e-8
    rope = RotaryPositionEmbedding(128)

    # Reference result: no autocast context is active for this call.
    reference = rope(max_seq_len=seq_len)

    # Same call, this time wrapped in a bfloat16 autocast region on CUDA.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        under_autocast = rope(max_seq_len=seq_len)

    # Both outputs are brought to bfloat16 before being compared.
    torch.testing.assert_close(
        reference.to(dtype=torch.bfloat16),
        under_autocast.to(dtype=torch.bfloat16),
        atol=tolerance,
        rtol=tolerance,
    )
...@@ -66,6 +66,9 @@ class RotaryPositionEmbedding(torch.nn.Module): ...@@ -66,6 +66,9 @@ class RotaryPositionEmbedding(torch.nn.Module):
""" """
Create rotary position embedding frequencies. Create rotary position embedding frequencies.
This function is particularly sensitive to the use of mixed precision, so we disable the
autocast context if it is enabled.
Parameters Parameters
---------- ----------
max_seq_len: int max_seq_len: int
...@@ -73,6 +76,7 @@ class RotaryPositionEmbedding(torch.nn.Module): ...@@ -73,6 +76,7 @@ class RotaryPositionEmbedding(torch.nn.Module):
offset: int, default = 0 offset: int, default = 0
Fixed offset for frequencies. Fixed offset for frequencies.
""" """
with torch.autocast(enabled=False, device_type="cuda"):
seq = ( seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset + offset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment