Unverified Commit 03ffd0a0 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Add comments on RoPE initialization (#1176)

parent a425bd9a
...@@ -264,6 +264,15 @@ class PagedAttentionWithRoPE(PagedAttention): ...@@ -264,6 +264,15 @@ class PagedAttentionWithRoPE(PagedAttention):
self.is_neox_style = is_neox_style self.is_neox_style = is_neox_style
# Create the cos and sin cache. # Create the cos and sin cache.
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange( inv_freq = 1.0 / (base**(torch.arange(
0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim)) 0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
t = torch.arange(max_position, dtype=torch.float, device="cuda") t = torch.arange(max_position, dtype=torch.float, device="cuda")
...@@ -274,7 +283,6 @@ class PagedAttentionWithRoPE(PagedAttention): ...@@ -274,7 +283,6 @@ class PagedAttentionWithRoPE(PagedAttention):
# FIXME(woosuk): This assumes that we configure the default dtype when # FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model. # initializing the model.
# TODO(woosuk): Make it more robust.
torch_dtype = torch.get_default_dtype() torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype) cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim] # Embedding size: [max_position, rotary_dim]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment