Unverified Commit 6d93d353 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[BugFix] tensor.get_device() -> tensor.device (#3604)

parent 837e1851
...@@ -108,7 +108,7 @@ class RotaryEmbedding(nn.Module): ...@@ -108,7 +108,7 @@ class RotaryEmbedding(nn.Module):
query_pass = query[..., self.rotary_dim:] query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:]
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device()) self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets) cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions] if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
...@@ -142,7 +142,7 @@ class RotaryEmbedding(nn.Module): ...@@ -142,7 +142,7 @@ class RotaryEmbedding(nn.Module):
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device()) self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
# ops.rotary_embedding()/batched_rotary_embedding() # ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors. # are in-place operations that update the query and key tensors.
if offsets is not None: if offsets is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment