"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "153d1361c7dcc91c7735cae73e1f594cfcab3e21"
Unverified commit 2a56edb3, authored by guillaume-be and committed by GitHub

Updated deberta attention (#14625)



* Removed unused p2p attention handling

* Updated DeBERTa configuration

* Updated TF DeBERTa attention

* Rolled back accidental comment deletion
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 824fd44f
@@ -82,8 +82,8 @@ class DebertaConfig(PretrainedConfig):
         position_biased_input (`bool`, *optional*, defaults to `True`):
             Whether add absolute position embedding to content embedding.
         pos_att_type (`List[str]`, *optional*):
-            The type of relative position attention, it can be a combination of `["p2c", "c2p", "p2p"]`, e.g.
-            `["p2c"]`, `["p2c", "c2p"]`, `["p2c", "c2p", 'p2p"]`.
+            The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g.
+            `["p2c"]`, `["p2c", "c2p"]`.
         layer_norm_eps (`float`, optional, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
     """
...
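With the `p2p` option removed, `pos_att_type` only accepts `"c2p"` and `"p2c"`. A minimal sketch of how the option is passed after this change (the randomly initialized model and the chosen values are purely illustrative):

```python
from transformers import DebertaConfig, DebertaModel

# Only content->position ("c2p") and position->content ("p2c") remain;
# "p2p" is no longer handled by the attention modules.
config = DebertaConfig(relative_attention=True, pos_att_type=["c2p", "p2c"])
model = DebertaModel(config)  # randomly initialized, for illustration only
```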
@@ -551,9 +551,9 @@ class DisentangledSelfAttention(nn.Module):
                 self.max_relative_positions = config.max_position_embeddings
             self.pos_dropout = StableDropout(config.hidden_dropout_prob)
-            if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "c2p" in self.pos_att_type:
                 self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
-            if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "p2c" in self.pos_att_type:
                 self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
         self.dropout = StableDropout(config.attention_probs_dropout_prob)
@@ -671,39 +671,35 @@ class DisentangledSelfAttention(nn.Module):
         rel_embeddings = rel_embeddings[
             self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
         ].unsqueeze(0)
-        if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
-            pos_key_layer = self.pos_proj(rel_embeddings)
-            pos_key_layer = self.transpose_for_scores(pos_key_layer)
-        if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
-            pos_query_layer = self.pos_q_proj(rel_embeddings)
-            pos_query_layer = self.transpose_for_scores(pos_query_layer)

         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
+            pos_key_layer = self.pos_proj(rel_embeddings)
+            pos_key_layer = self.transpose_for_scores(pos_key_layer)
             c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
             score += c2p_att

         # position->content
-        if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+        if "p2c" in self.pos_att_type:
+            pos_query_layer = self.pos_q_proj(rel_embeddings)
+            pos_query_layer = self.transpose_for_scores(pos_query_layer)
             pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
             if query_layer.size(-2) != key_layer.size(-2):
                 r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            if query_layer.size(-2) != key_layer.size(-2):
-                pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
-
-        if "p2c" in self.pos_att_type:
             p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
             p2c_att = torch.gather(
                 p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
             ).transpose(-1, -2)
             if query_layer.size(-2) != key_layer.size(-2):
+                pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
                 p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
             score += p2c_att
...
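For orientation, a toy sketch of how the two remaining bias terms are accumulated once the p2p branch is gone. This is not the library code: the tensor shapes, the helper name, and the precomputed gather indices are assumed to be prepared along the lines of the diff above.

```python
import math
import torch

def disentangled_bias_sketch(query_layer, key_layer, pos_key_layer, pos_query_layer,
                             c2p_index, p2c_index, scale_factor):
    """Toy c2p + p2c bias accumulation; the former p2p branch simply no longer exists."""
    score = 0
    # content->position: content queries attend over relative-position key embeddings
    c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
    score = score + torch.gather(c2p_att, dim=-1, index=c2p_index)
    # position->content: relative-position queries attend over content keys
    pos_query_layer = pos_query_layer / math.sqrt(pos_query_layer.size(-1) * scale_factor)
    p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
    score = score + torch.gather(p2c_att, dim=-1, index=p2c_index).transpose(-1, -2)
    return score
```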
@@ -516,14 +516,14 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
             if self.max_relative_positions < 1:
                 self.max_relative_positions = config.max_position_embeddings
             self.pos_dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="pos_dropout")
-            if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "c2p" in self.pos_att_type:
                 self.pos_proj = tf.keras.layers.Dense(
                     self.all_head_size,
                     kernel_initializer=get_initializer(config.initializer_range),
                     name="pos_proj",
                     use_bias=False,
                 )
-            if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+            if "p2c" in self.pos_att_type:
                 self.pos_q_proj = tf.keras.layers.Dense(
                     self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="pos_q_proj"
                 )
@@ -677,32 +677,28 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
         rel_embeddings = tf.expand_dims(
             rel_embeddings[self.max_relative_positions - att_span : self.max_relative_positions + att_span, :], 0
         )
-        if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
-            pos_key_layer = self.pos_proj(rel_embeddings)
-            pos_key_layer = self.transpose_for_scores(pos_key_layer)
-        if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
-            pos_query_layer = self.pos_q_proj(rel_embeddings)
-            pos_query_layer = self.transpose_for_scores(pos_query_layer)

         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
+            pos_key_layer = self.pos_proj(rel_embeddings)
+            pos_key_layer = self.transpose_for_scores(pos_key_layer)
             c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 1, 3, 2]))
             c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch_gather(c2p_att, c2p_dynamic_expand(c2p_pos, query_layer, relative_pos), -1)
             score += c2p_att

         # position->content
-        if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+        if "p2c" in self.pos_att_type:
+            pos_query_layer = self.pos_q_proj(rel_embeddings)
+            pos_query_layer = self.transpose_for_scores(pos_query_layer)
             pos_query_layer /= tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=tf.float32))
             if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]:
                 r_pos = build_relative_position(shape_list(key_layer)[-2], shape_list(key_layer)[-2])
             else:
                 r_pos = relative_pos
             p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1)
-
-        if "p2c" in self.pos_att_type:
             p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 1, 3, 2]))
             p2c_att = tf.transpose(
                 torch_gather(p2c_att, p2c_dynamic_expand(p2c_pos, query_layer, key_layer), -1), [0, 1, 3, 2]
...
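A quick way to sanity-check that the TF attention still builds and runs after the change. This is not part of the commit; it assumes network access to download the `microsoft/deberta-base` checkpoint.

```python
from transformers import DebertaTokenizer, TFDebertaModel

tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
# If the checkpoint only ships PyTorch weights, pass from_pt=True (requires torch installed).
model = TFDebertaModel.from_pretrained("microsoft/deberta-base")

inputs = tokenizer("DeBERTa uses disentangled attention.", return_tensors="tf")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)
```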