Commit 375b4e09 authored by Yih-Dar, committed by GitHub

Fix `cos_sin` device issue in Falcon model (#26448)



* fix

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a7e0ed82
@@ -129,6 +129,11 @@ class FalconRotaryEmbedding(nn.Module):
         total_length = seq_len + past_key_values_length
         if total_length > self.seq_len_cached:
             self._set_cos_sin_cache(total_length, device, dtype)
+        # the cached tensors need to update their devices (for example, after we change the model's device)
+        self.cos_cached = self.cos_cached.to(device)
+        self.sin_cached = self.sin_cached.to(device)
         # Gather cos, sin at the designated position ids
         cos = self.cos_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
         sin = self.sin_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
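
The device handling in this hunk can be illustrated with a minimal sketch. It is not the Falcon implementation: `RotaryCacheSketch`, its `cos_sin` method, and the `head_dim` and `base` parameters are hypothetical names chosen for illustration. It also assumes the cos/sin cache is held as plain attributes rather than registered buffers, so that moving the module does not move the cache.

```python
import torch
from torch import nn


class RotaryCacheSketch(nn.Module):
    """Hypothetical, simplified stand-in for a rotary cos/sin cache."""

    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.seq_len_cached = 0
        # Assumption: plain attributes, so `module.to(device)` leaves them behind.
        self.cos_cached = None
        self.sin_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, head_dim]
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)

    def cos_sin(self, seq_len, past_key_values_length, position_ids, device, dtype):
        total_length = seq_len + past_key_values_length
        if total_length > self.seq_len_cached:
            self._set_cos_sin_cache(total_length, device, dtype)
        # The cache is only rebuilt when it is too short, so a cache built on an
        # earlier device must be moved explicitly; `.to` is a no-op when the
        # tensor is already on `device`.
        self.cos_cached = self.cos_cached.to(device)
        self.sin_cached = self.sin_cached.to(device)
        # Gather cos, sin at the requested positions: [bs, seq_len, head_dim]
        return self.cos_cached[position_ids], self.sin_cached[position_ids]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = RotaryCacheSketch(head_dim=8).to(device)
position_ids = torch.arange(4, device=device).unsqueeze(0)
cos, sin = module.cos_sin(4, 0, position_ids, device, torch.float32)
print(cos.shape)  # torch.Size([1, 4, 8])
```

Under these assumptions, dropping the two `.to(device)` lines would leave a cache built on one device in place after the model moves to another. The next call, if it needs no longer cache, would then index it with `position_ids` from a different device.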