Commit 71f674ae authored by Tri Dao

[Rotary] Customize base, support seqlen_offset

parent d6ef701a
@@ -136,20 +136,20 @@ class RotaryEmbedding(torch.nn.Module):
     """
-    def __init__(self, dim_model: int, *_, **__):
+    def __init__(self, dim: int, base=10000, *_, **__):
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
         self._seq_len_cached = 0
         self._cos_cached = None
         self._sin_cached = None

-    def _update_cos_sin_cache(self, x):
+    def _update_cos_sin_cache(self, x, seqlen_offset=0):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
         """
-        seqlen = x.shape[1]
+        seqlen = x.shape[1] + seqlen_offset
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
@@ -162,6 +162,11 @@ class RotaryEmbedding(torch.nn.Module):
             self._cos_cached = torch.cos(freqs).to(x.dtype)
             self._sin_cached = torch.sin(freqs).to(x.dtype)

-    def forward(self, qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv)
-        return apply_rotary_emb_qkv_(qkv, self._cos_cached, self._sin_cached)
+    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        seqlen_offset: can be used in generation where the qkv being passed in is only the last
+        token in the batch.
+        """
+        self._update_cos_sin_cache(qkv, seqlen_offset)
+        return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
+                                     self._sin_cached[seqlen_offset:])
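
The sketch below is a standalone pure-PyTorch illustration of what the two new arguments do, not the fused flash-attn kernel; the helpers `rotary_cos_sin` and `apply_rotary` are hypothetical names for this example. `base` rescales the geometric spacing of the rotation frequencies, and `seqlen_offset` skips the first rows of the cached cos/sin tables so that, during generation, a qkv containing only the newest token is rotated at its true position.

```python
# Minimal sketch, assuming the non-interleaved (rotate-half) rotary convention.
import torch

def rotary_cos_sin(dim, seqlen, base=10000.0, device=None, dtype=torch.float32):
    # Custom `base` changes the spacing of the rotation frequencies.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)            # (seqlen, dim / 2)
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

def apply_rotary(x, cos, sin, seqlen_offset=0):
    # x: (batch, seqlen, nheads, headdim). During generation, x holds only the
    # newest token(s), so skip the first `seqlen_offset` rows of the tables.
    cos = cos[seqlen_offset:seqlen_offset + x.shape[1]][None, :, None, :]
    sin = sin[seqlen_offset:seqlen_offset + x.shape[1]][None, :, None, :]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Prefill a 16-token prompt, then rotate a single decoded token at position 16.
cos, sin = rotary_cos_sin(dim=64, seqlen=128, base=10000.0)
prompt = torch.randn(2, 16, 8, 64)
out_prefill = apply_rotary(prompt, cos, sin, seqlen_offset=0)
next_tok = torch.randn(2, 1, 8, 64)
out_step = apply_rotary(next_tok, cos, sin, seqlen_offset=16)
```

Without the offset, a single-token step would be rotated as if it were at position 0, which is what this commit fixes for incremental decoding.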