Unverified commit 319cbbe1, authored by guillaume-be and committed by GitHub

Deberta v2 code simplification (#15732)

* Removed spurious subtraction

* Fixed condition checking for attention type

* Fixed sew_d copy of DeBERTa v2 attention

* Removed unused `p2p` attention type from DebertaV2-class models

* Fixed docs style
parent 0a5ef036
@@ -77,8 +77,8 @@ class DebertaV2Config(PretrainedConfig):
         position_biased_input (`bool`, *optional*, defaults to `False`):
             Whether add absolute position embedding to content embedding.
         pos_att_type (`List[str]`, *optional*):
-            The type of relative position attention, it can be a combination of `["p2c", "c2p", "p2p"]`, e.g.
-            `["p2c"]`, `["p2c", "c2p"]`, `["p2c", "c2p", "p2p"]`.
+            The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`,
+            `["p2c", "c2p"]`.
         layer_norm_eps (`float`, optional, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
     """
...
@@ -629,9 +629,9 @@ class DisentangledSelfAttention(nn.Module):
             self.pos_dropout = StableDropout(config.hidden_dropout_prob)
 
             if not self.share_att_key:
-                if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+                if "c2p" in self.pos_att_type:
                     self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
-                if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+                if "p2c" in self.pos_att_type:
                     self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
 
         self.dropout = StableDropout(config.attention_probs_dropout_prob)
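A small standalone sketch of what the constructor now does (the helper below is illustrative, not part of the library): dedicated projections of the relative embeddings are only created for the components that remain enabled, since `p2p` no longer forces them into existence.

```python
import torch.nn as nn

def positional_projections(hidden_size, all_head_size, pos_att_type, share_att_key):
    """Illustrative helper mirroring the conditional layer creation above."""
    layers = nn.ModuleDict()
    if not share_att_key:
        # "c2p" scores project the relative embeddings as keys,
        # "p2c" scores project them as queries; "p2p" no longer matters.
        if "c2p" in pos_att_type:
            layers["pos_key_proj"] = nn.Linear(hidden_size, all_head_size, bias=True)
        if "p2c" in pos_att_type:
            layers["pos_query_proj"] = nn.Linear(hidden_size, all_head_size)
    return layers

print(sorted(positional_projections(128, 128, ["p2c", "c2p"], share_att_key=False).keys()))
# ['pos_key_proj', 'pos_query_proj']
```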
@@ -692,8 +692,6 @@ class DisentangledSelfAttention(nn.Module):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        if "p2p" in self.pos_att_type:
-            scale_factor += 1
         scale = math.sqrt(query_layer.size(-1) * scale_factor)
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
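The scaling kept by this hunk counts one term per enabled attention component, so with `p2p` gone the factor is at most 3. A plain-Python sketch with an assumed head dimension:

```python
import math

def attention_scale(head_dim, pos_att_type):
    # 1 for content->content, plus 1 for each enabled relative component.
    scale_factor = 1
    if "c2p" in pos_att_type:
        scale_factor += 1
    if "p2c" in pos_att_type:
        scale_factor += 1
    return math.sqrt(head_dim * scale_factor)

print(round(attention_scale(64, ["p2c", "c2p"]), 2))  # sqrt(64 * 3) = 13.86
```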
@@ -744,7 +742,7 @@ class DisentangledSelfAttention(nn.Module):
         att_span = self.pos_ebd_size
         relative_pos = relative_pos.long().to(query_layer.device)
 
-        rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze(0)
+        rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
         if self.share_att_key:
             pos_query_layer = self.transpose_for_scores(
                 self.query_proj(rel_embeddings), self.num_attention_heads
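This is the "spurious subtraction" from the commit message: `att_span` is assigned `self.pos_ebd_size` two lines earlier, so the lower slice bound was always 0 and the slice is simply the first `2 * att_span` rows. A quick equivalence check with dummy tensors (names and sizes are assumptions for the sketch):

```python
import torch

pos_ebd_size = 256                  # stands in for self.pos_ebd_size
att_span = pos_ebd_size             # exactly the assignment shown above
rel_embeddings = torch.randn(2 * pos_ebd_size, 32)

old = rel_embeddings[pos_ebd_size - att_span : pos_ebd_size + att_span, :].unsqueeze(0)
new = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
print(torch.equal(old, new))  # True
```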
@@ -753,13 +751,13 @@ class DisentangledSelfAttention(nn.Module):
                 query_layer.size(0) // self.num_attention_heads, 1, 1
             )
         else:
-            if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "c2p" in self.pos_att_type:
                 pos_key_layer = self.transpose_for_scores(
                     self.pos_key_proj(rel_embeddings), self.num_attention_heads
                 ).repeat(
                     query_layer.size(0) // self.num_attention_heads, 1, 1
                 )  # .split(self.all_head_size, dim=-1)
-            if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "p2c" in self.pos_att_type:
                 pos_query_layer = self.transpose_for_scores(
                     self.pos_query_proj(rel_embeddings), self.num_attention_heads
                 ).repeat(
@@ -780,7 +778,7 @@ class DisentangledSelfAttention(nn.Module):
             score += c2p_att / scale
 
         # position->content
-        if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+        if "p2c" in self.pos_att_type:
             scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
@@ -794,8 +792,6 @@ class DisentangledSelfAttention(nn.Module):
                 r_pos = relative_pos
 
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            if "p2c" in self.pos_att_type:
-                p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
-                p2c_att = torch.gather(
-                    p2c_att,
+            p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.gather(
+                p2c_att,
@@ -804,20 +800,6 @@ class DisentangledSelfAttention(nn.Module):
-                ).transpose(-1, -2)
-                score += p2c_att / scale
-
-            # position->position
-            if "p2p" in self.pos_att_type:
-                pos_query = pos_query_layer[:, :, att_span:, :]
-                p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
-                p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
-                p2p_att = torch.gather(
-                    p2p_att,
-                    dim=-1,
-                    index=c2p_pos.expand(
-                        [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
-                    ),
-                )
-                score += p2p_att
+            ).transpose(-1, -2)
+            score += p2c_att / scale
 
         return score
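After this hunk, the disentangled bias returned by the function is the sum of at most two terms, content-to-position and position-to-content; the former `score += p2p_att` path is gone entirely. A toy recombination with random tensors (shapes are illustrative only):

```python
import math
import torch

batch_heads, q_len, k_len, head_dim = 8, 16, 16, 64   # assumed shapes
scale = math.sqrt(head_dim * 3)                        # scale_factor = 3, as above

c2p_att = torch.randn(batch_heads, q_len, k_len)       # stand-ins for the gathered
p2c_att = torch.randn(batch_heads, q_len, k_len)       # c2p / p2c attention maps

# The bias is now only these two contributions; no position->position term.
score = c2p_att / scale + p2c_att / scale
print(score.shape)  # torch.Size([8, 16, 16])
```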
...
@@ -604,14 +604,14 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
             self.pos_dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="pos_dropout")
 
             if not self.share_att_key:
-                if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+                if "c2p" in self.pos_att_type:
                     self.pos_proj = tf.keras.layers.Dense(
                         self.all_head_size,
                         kernel_initializer=get_initializer(config.initializer_range),
                         name="pos_proj",
                         use_bias=True,
                     )
-                if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+                if "p2c" in self.pos_att_type:
                     self.pos_q_proj = tf.keras.layers.Dense(
                         self.all_head_size,
                         kernel_initializer=get_initializer(config.initializer_range),
@@ -679,8 +679,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        if "p2p" in self.pos_att_type:
-            scale_factor += 1
         scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, tf.float32))
         attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1])) / scale
         if self.relative_attention:
@@ -749,12 +747,12 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
                 [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
             )
         else:
-            if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "c2p" in self.pos_att_type:
                 pos_key_layer = tf.tile(
                     self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads),
                     [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
                 )  # .split(self.all_head_size, dim=-1)
-            if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "p2c" in self.pos_att_type:
                 pos_query_layer = tf.tile(
                     self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads),
                     [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
@@ -777,7 +775,7 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
             score += c2p_att / scale
 
         # position->content
-        if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+        if "p2c" in self.pos_att_type:
             scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, tf.float32))
             if shape_list(key_layer)[-2] != shape_list(query_layer)[-2]:
                 r_pos = build_relative_position(
@@ -792,7 +790,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
             p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1)
-            if "p2c" in self.pos_att_type:
-                p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 2, 1]))
-                p2c_att = tf.transpose(
-                    take_along_axis(
+            p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 2, 1]))
+            p2c_att = tf.transpose(
+                take_along_axis(
@@ -807,26 +804,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
-                )
-                score += p2c_att / scale
-
-            # position->position
-            if "p2p" in self.pos_att_type:
-                pos_query = pos_query_layer[:, :, att_span:, :]
-                p2p_att = tf.matmul(pos_query, tf.transpose(pos_key_layer, [0, 2, 1]))
-                p2p_att = tf.broadcast_to(shape_list(query_layer)[:2] + shape_list(p2p_att)[2:])
-                p2p_att = take_along_axis(
-                    p2p_att,
-                    tf.broadcast_to(
-                        c2p_pos,
-                        [
-                            shape_list(query_layer)[0],
-                            shape_list(query_layer)[1],
-                            shape_list(query_layer)[2],
-                            shape_list(relative_pos)[-1],
-                        ],
-                    ),
-                    -1,
-                )
-                score += p2p_att
+            )
+            score += p2c_att / scale
 
         return score
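The TensorFlow changes mirror the PyTorch ones above. For reference, a small sketch of how the clamped indices used by the remaining gathers are built; it re-derives the unbucketed relative positions for tiny sizes (the real module may additionally bucket positions), so names and sizes here are assumptions:

```python
import torch

def toy_relative_position(query_size, key_size):
    # rel[i, j] = i - j, the query-to-key distance used by DeBERTa-style attention.
    q_ids = torch.arange(query_size)
    k_ids = torch.arange(key_size)
    return q_ids[:, None] - k_ids[None, :]

att_span = 4                                  # assumed span (pos_ebd_size)
rel_pos = toy_relative_position(6, 6)

c2p_pos = torch.clamp(rel_pos + att_span, 0, att_span * 2 - 1)   # content->position
p2c_pos = torch.clamp(-rel_pos + att_span, 0, att_span * 2 - 1)  # position->content

print(c2p_pos[0].tolist())  # [4, 3, 2, 1, 0, 0]
print(p2c_pos[0].tolist())  # [4, 5, 6, 7, 7, 7]
```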
...
@@ -66,8 +66,8 @@ class SEWDConfig(PretrainedConfig):
         position_biased_input (`bool`, *optional*, defaults to `False`):
             Whether to add absolute position embedding to content embedding.
         pos_att_type (`Tuple[str]`, *optional*, defaults to `("p2c", "c2p")`):
-            The type of relative position attention, it can be a combination of `("p2c", "c2p", "p2p")`, e.g.
-            `("p2c")`, `("p2c", "c2p")`, `("p2c", "c2p", "p2p")`.
+            The type of relative position attention, it can be a combination of `("p2c", "c2p")`, e.g. `("p2c")`,
+            `("p2c", "c2p")`.
         norm_rel_ebd (`str`, *optional*, defaults to `"layer_norm"`):
             Whether to use layer norm in relative embedding (`"layer_norm"` if yes)
         hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
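SEW-D reuses DeBERTa's disentangled attention, so the same restriction applies to its config. A minimal illustrative instantiation (the tuple shown is also the documented default):

```python
from transformers import SEWDConfig

# Illustrative only: "p2p" is no longer a documented or supported entry.
config = SEWDConfig(pos_att_type=("p2c", "c2p"))
print(config.pos_att_type)  # the two remaining attention types
```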
...
@@ -703,9 +703,9 @@ class DisentangledSelfAttention(nn.Module):
             self.pos_dropout = StableDropout(config.activation_dropout)
 
             if not self.share_att_key:
-                if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+                if "c2p" in self.pos_att_type:
                     self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
-                if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+                if "p2c" in self.pos_att_type:
                     self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
 
         self.dropout = StableDropout(config.attention_dropout)
@@ -766,8 +766,6 @@ class DisentangledSelfAttention(nn.Module):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        if "p2p" in self.pos_att_type:
-            scale_factor += 1
         scale = math.sqrt(query_layer.size(-1) * scale_factor)
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
@@ -818,7 +816,7 @@ class DisentangledSelfAttention(nn.Module):
         att_span = self.pos_ebd_size
         relative_pos = relative_pos.long().to(query_layer.device)
 
-        rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze(0)
+        rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
         if self.share_att_key:
             pos_query_layer = self.transpose_for_scores(
                 self.query_proj(rel_embeddings), self.num_attention_heads
@@ -827,13 +825,13 @@ class DisentangledSelfAttention(nn.Module):
                 query_layer.size(0) // self.num_attention_heads, 1, 1
             )
         else:
-            if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "c2p" in self.pos_att_type:
                 pos_key_layer = self.transpose_for_scores(
                     self.pos_key_proj(rel_embeddings), self.num_attention_heads
                 ).repeat(
                     query_layer.size(0) // self.num_attention_heads, 1, 1
                 )  # .split(self.all_head_size, dim=-1)
-            if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "p2c" in self.pos_att_type:
                 pos_query_layer = self.transpose_for_scores(
                     self.pos_query_proj(rel_embeddings), self.num_attention_heads
                 ).repeat(
@@ -854,7 +852,7 @@ class DisentangledSelfAttention(nn.Module):
             score += c2p_att / scale
 
         # position->content
-        if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+        if "p2c" in self.pos_att_type:
             scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
@@ -868,8 +866,6 @@ class DisentangledSelfAttention(nn.Module):
                 r_pos = relative_pos
 
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            if "p2c" in self.pos_att_type:
-                p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
-                p2c_att = torch.gather(
-                    p2c_att,
+            p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.gather(
+                p2c_att,
@@ -878,20 +874,6 @@ class DisentangledSelfAttention(nn.Module):
-                ).transpose(-1, -2)
-                score += p2c_att / scale
-
-            # position->position
-            if "p2p" in self.pos_att_type:
-                pos_query = pos_query_layer[:, :, att_span:, :]
-                p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
-                p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
-                p2p_att = torch.gather(
-                    p2p_att,
-                    dim=-1,
-                    index=c2p_pos.expand(
-                        [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
-                    ),
-                )
-                score += p2p_att
+            ).transpose(-1, -2)
+            score += p2c_att / scale
 
         return score
...