"...resnet50_tensorflow.git" did not exist on "8c9343b7f453738612e49def1ee9d4f532e6e811"
Unverified commit b6204c9e, authored by Sander Land, committed by GitHub

fix warnings in deberta (#19458)



* fix warnings in deberta

* fix copies

* Revert "fix copies"

This reverts commit 324cb3fed11e04190ba7b4662644baa8143b60be.

* fix copies

* fix copies again

* revert the whitespace changes that make style introduced, since they result in an infinite chain of fix-copies

* argh
Co-authored-by: Sander Land <sander@chatdesk.com>
parent de64d671
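Every hunk below applies the same pattern: the old code re-wrapped an existing tensor with torch.tensor(existing_tensor, dtype=...), which PyTorch answers with a copy-construct UserWarning and an unnecessary copy; the fix casts with existing_tensor.to(dtype=...) instead. A minimal standalone sketch of the before/after (not repository code, just the two idioms side by side):

import torch

x = torch.randn(2, 4)

# Old pattern: re-wrapping an existing tensor copies it and emits a warning,
# roughly "UserWarning: To copy construct from a tensor, it is recommended to
# use sourceTensor.clone().detach() ... rather than torch.tensor(sourceTensor)."
y_old = torch.tensor(x, dtype=torch.float16)

# New pattern (used throughout this diff): a plain dtype cast, warning-free.
y_new = x.to(dtype=torch.float16)

assert torch.equal(y_old, y_new)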
@@ -665,8 +665,8 @@ class DisentangledSelfAttention(nn.Module):
             qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
             qkvb = [None] * 3
 
-            q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
-            k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
+            q = linear(qkvw[0], qkvb[0], query_states.to(dtype=qkvw[0].dtype))
+            k, v = [linear(qkvw[i], qkvb[i], hidden_states.to(dtype=qkvw[i].dtype)) for i in range(1, 3)]
         query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
 
         query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
@@ -676,7 +676,7 @@ class DisentangledSelfAttention(nn.Module):
         # Take the dot product between "query" and "key" to get the raw attention scores.
         scale_factor = 1 + len(self.pos_att_type)
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
+        query_layer = query_layer / scale.to(dtype=query_layer.dtype)
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -742,7 +742,7 @@ class DisentangledSelfAttention(nn.Module):
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
+            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype))
             p2c_att = torch.gather(
                 p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
             ).transpose(-1, -2)
......
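The remaining hunks touch the scaled dot-product. Since scale is already a 0-dim tensor (it comes from torch.sqrt(torch.tensor(query_layer.size(-1), ...))), dividing by scale.to(dtype=...) is a direct drop-in for the old torch.tensor(scale, dtype=...) wrapper, which also lets the three-line division collapse to one. A short standalone sketch; the sizes below are illustrative, not taken from the model:

import torch

head_dim, scale_factor = 64, 3  # illustrative values only

# torch.tensor() on a plain Python int is fine; scale ends up a 0-dim tensor.
scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float) * scale_factor)

scores = torch.randn(2, 8, 8, dtype=torch.float16)

# Old: scores / torch.tensor(scale, dtype=scores.dtype)  -> copy-construct warning
# New: cast the existing scale tensor instead.
scores = scores / scale.to(dtype=scores.dtype)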
@@ -742,9 +742,7 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
-            scale, dtype=query_layer.dtype
-        )
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype)
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -826,7 +824,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
+            score += c2p_att / scale.to(dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -849,7 +847,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
+            score += p2c_att / scale.to(dtype=p2c_att.dtype)
         return score
......
@@ -791,9 +791,7 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
-            scale, dtype=query_layer.dtype
-        )
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype)
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -875,7 +873,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
+            score += c2p_att / scale.to(dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -898,7 +896,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
+            score += p2c_att / scale.to(dtype=p2c_att.dtype)
         return score
......