Unverified commit 86d0b26d, authored by Jingya HUANG and committed by GitHub

Fix matmul inputs dtype (#18585)

parent c99e9846
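
The diff below casts the inputs of several matmul/bmm calls, and the scale divisor, to the dtype of the tensor they are combined with. As a hedged, self-contained sketch of the failure mode such casts guard against (the shapes, dtypes, and variable names here are illustrative, not taken from the repository):

import torch

# Hypothetical repro: torch.matmul requires both operands to share a dtype, so
# mixing float32 weights with float16 activations raises a RuntimeError.
w = torch.randn(8, 8)                        # float32 projection weight
x = torch.randn(2, 8, dtype=torch.float16)   # half-precision activations

try:
    torch.matmul(x, w.t())                   # mixed Float/Half operands fail
except RuntimeError as err:
    print("without cast:", err)

# Casting the activation to the weight's dtype first, mirroring the diff's
# torch.tensor(query_states, dtype=qkvw[0].dtype), makes the product well defined.
y = torch.matmul(torch.tensor(x, dtype=w.dtype), w.t())
print("with cast:", y.dtype)                 # torch.float32

In eager PyTorch, x.to(w.dtype) performs the same cast; the sketch simply mirrors the torch.tensor(..., dtype=...) spelling used in the patch.
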
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch DeBERTa model."""
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
@@ -640,8 +639,8 @@ class DisentangledSelfAttention(nn.Module):
             qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
             qkvb = [None] * 3
-            q = linear(qkvw[0], qkvb[0], query_states)
-            k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
+            q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
+            k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
             query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
         query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
@@ -650,8 +649,8 @@ class DisentangledSelfAttention(nn.Module):
         rel_att = None
         # Take the dot product between "query" and "key" to get the raw attention scores.
         scale_factor = 1 + len(self.pos_att_type)
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
-        query_layer = query_layer / scale
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+        query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -711,13 +710,13 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             pos_query_layer = self.pos_q_proj(rel_embeddings)
             pos_query_layer = self.transpose_for_scores(pos_query_layer)
-            pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if query_layer.size(-2) != key_layer.size(-2):
                 r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
             p2c_att = torch.gather(
                 p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
             ).transpose(-1, -2)
...
@@ -717,7 +717,9 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -799,7 +801,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
         # position->content
         if "p2c" in self.pos_att_type:
@@ -822,7 +824,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
         return score
...
@@ -791,7 +791,9 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -873,7 +875,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
         # position->content
         if "p2c" in self.pos_att_type:
@@ -896,7 +898,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
         return score
...
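
The same two-part pattern repeats in each touched modeling file: the attention scale is built with torch.sqrt over a torch.float tensor (in place of math.sqrt), and every division by that scale casts it to the dtype of the tensor being scaled. A minimal sketch of the pattern, with an assumed shape and an assumed scale_factor rather than values from the repository:

import torch

# Illustrative half-precision query tensor; the shape and scale_factor are assumptions.
query_layer = torch.randn(2, 12, 64, dtype=torch.float16)
scale_factor = 3  # e.g. 1 + len(pos_att_type) with two relative-attention types enabled

# Build the scale as a float32 tensor (torch.sqrt instead of math.sqrt) ...
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
# ... then cast it to the dtype of the tensor it divides, as the diff does.
query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
print(query_layer.dtype)  # torch.float16

With the cast in place, every operand that later feeds torch.matmul or torch.bmm carries the same explicit dtype.
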