Unverified Commit f8265387 authored by Arman Cohan's avatar Arman Cohan Committed by GitHub
Browse files

create tensors on device (#12846)

parent fbf468b0
...@@ -393,15 +393,18 @@ class T5Attention(nn.Module): ...@@ -393,15 +393,18 @@ class T5Attention(nn.Module):
def compute_bias(self, query_length, key_length): def compute_bias(self, query_length, key_length):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long)[:, None] context_position = torch.arange(
memory_position = torch.arange(key_length, dtype=torch.long)[None, :] query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[:, None]
memory_position = torch.arange(
key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket( relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length) relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder), bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets, num_buckets=self.relative_attention_num_buckets,
) )
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values return values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment