Unverified Commit 72ad03ea authored by Tri Dao's avatar Tri Dao Committed by GitHub
Browse files

Merge pull request #299 from proger/rotary-inference-mode

rotary: update cos/sin cache when switching from inference mode
parents 2800efc7 70ab266a
...@@ -211,9 +211,11 @@ class RotaryEmbedding(torch.nn.Module): ...@@ -211,9 +211,11 @@ class RotaryEmbedding(torch.nn.Module):
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance) # if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (seqlen > self._seq_len_cached or self._cos_cached.device != device if (seqlen > self._seq_len_cached or self._cos_cached.device != device
or self._cos_cached.dtype != dtype): or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())):
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision. # And the output of arange can be quite large, so bf16 would lose a lot of precision.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment