"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7e7fc53da5f230db379ece739457c81b2f50f13e"
Unverified Commit b1e1a9f9 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2495 from mschrimpf/patch-1

T5: move rp_bucket to relative_attention_bias' device
parents 331065e6 90d3b787
@@ -286,6 +286,7 @@ class T5Attention(nn.Module):
             bidirectional=not self.is_decoder,
             num_buckets=self.relative_attention_num_buckets,
         )
+        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
         values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
         return values
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment