Unverified Commit 0b54046f authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

remove unwanted code (#13145)

parent 2e20c0f3
...@@ -753,8 +753,6 @@ class DisentangledSelfAttention(nn.Module): ...@@ -753,8 +753,6 @@ class DisentangledSelfAttention(nn.Module):
r_pos = relative_pos r_pos = relative_pos
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
if query_layer.size(-2) != key_layer.size(-2):
pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
if "p2c" in self.pos_att_type: if "p2c" in self.pos_att_type:
p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
...@@ -763,12 +761,6 @@ class DisentangledSelfAttention(nn.Module): ...@@ -763,12 +761,6 @@ class DisentangledSelfAttention(nn.Module):
dim=-1, dim=-1,
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
).transpose(-1, -2) ).transpose(-1, -2)
if query_layer.size(-2) != key_layer.size(-2):
p2c_att = torch.gather(
p2c_att,
dim=-2,
index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))),
)
score += p2c_att / scale score += p2c_att / scale
# position->position # position->position
...@@ -776,12 +768,6 @@ class DisentangledSelfAttention(nn.Module): ...@@ -776,12 +768,6 @@ class DisentangledSelfAttention(nn.Module):
pos_query = pos_query_layer[:, :, att_span:, :] pos_query = pos_query_layer[:, :, att_span:, :]
p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2)) p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:]) p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
if query_layer.size(-2) != key_layer.size(-2):
p2p_att = torch.gather(
p2p_att,
dim=-2,
index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))),
)
p2p_att = torch.gather( p2p_att = torch.gather(
p2p_att, p2p_att,
dim=-1, dim=-1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment