Unverified commit 8eec2004, authored by Peter St. John, committed by GitHub

Disable torch autocast context in rope forward pass (#2240)


Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7ad130ef
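
For context on why this change is needed: on CUDA devices, torch.einsum is autocast-eligible, so inside an autocast region the rotary frequency computation would silently run in the low-precision autocast dtype. A minimal sketch of the failure mode, not part of this commit, assuming a CUDA device is available:

import torch

seq = torch.arange(16, device="cuda", dtype=torch.float32)
inv_freq = torch.rand(8, device="cuda", dtype=torch.float32)

# Outside autocast, einsum on fp32 inputs stays in fp32.
assert torch.einsum("i , j -> i j", seq, inv_freq).dtype == torch.float32

# Inside autocast, einsum inputs are cast to the autocast dtype, so the
# rotary frequencies would lose precision.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    assert torch.einsum("i , j -> i j", seq, inv_freq).dtype == torch.bfloat16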
@@ -373,3 +373,19 @@ def test_fused_qkv_rope(
     if not isinstance(start_positions, torch.Tensor):
         torch.testing.assert_close(grad_fused, grad_unfused)
+
+
+def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast():
+    rope_layer = RotaryPositionEmbedding(128)
+    rope_embeddings_no_autocast = rope_layer(max_seq_len=1024)
+
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        rope_embeddings_autocast = rope_layer(max_seq_len=1024)
+
+    torch.testing.assert_close(
+        rope_embeddings_no_autocast.to(dtype=torch.bfloat16),
+        rope_embeddings_autocast.to(dtype=torch.bfloat16),
+        atol=1e-8,
+        rtol=1e-8,
+    )
@@ -66,6 +66,9 @@ class RotaryPositionEmbedding(torch.nn.Module):
         """
         Create rotary position embedding frequencies.
 
+        This function is particularly sensitive to the use of mixed precision, so we disable the
+        autocast context if it is enabled.
+
         Parameters
         ----------
         max_seq_len: int
@@ -73,26 +76,27 @@ class RotaryPositionEmbedding(torch.nn.Module):
         offset: int, default = 0
             Fixed offset for frequencies.
         """
-        seq = (
-            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-            + offset
-        )
-
-        if (
-            self.pretrained_max_position_embeddings is not None
-            and self.seq_len_interpolation_factor is not None
-        ):
-            if (
-                max_seq_len
-                > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
-            ):
-                # dynamic linear scaling (length > position we have learned)
-                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
-            else:
-                # fixed linear scaling
-                seq *= 1 / self.seq_len_interpolation_factor
-
-        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
+        with torch.autocast(enabled=False, device_type="cuda"):
+            seq = (
+                torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+                + offset
+            )
+
+            if (
+                self.pretrained_max_position_embeddings is not None
+                and self.seq_len_interpolation_factor is not None
+            ):
+                if (
+                    max_seq_len
+                    > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
+                ):
+                    # dynamic linear scaling (length > position we have learned)
+                    seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
+                else:
+                    # fixed linear scaling
+                    seq *= 1 / self.seq_len_interpolation_factor
+
+            freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
 
         # first part even vector components, second part odd vector components,
         # 2 * dim in dimension size
         if not self.interleaved:
...
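
The fix uses the standard PyTorch pattern of nesting a disabled autocast context inside an enabled one: ops in the inner block run at their inputs' native precision. A hedged illustration of that pattern in isolation, again assuming a CUDA device:

import torch

x = torch.rand(4, 4, device="cuda", dtype=torch.float32)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = x @ x  # matmul under autocast runs in bf16
    with torch.autocast(enabled=False, device_type="cuda"):
        z = x @ x  # autocast disabled: matmul stays in fp32

assert y.dtype == torch.bfloat16
assert z.dtype == torch.float32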