Unverified commit 57713443 authored by Ahmed Elnaggar, committed by GitHub

Configurable Relative Position Max. Distance (#16155)



* Configurable Relative Position Max. Distance

* fix missing config
Co-authored-by: ahmed-elnaggar <ahmed.elnaggar@allianz.com>
parent cd1ffb40
src/transformers/models/t5/configuration_t5.py
@@ -60,6 +60,8 @@ class T5Config(PretrainedConfig):
Number of attention heads for each attention layer in the Transformer encoder.
relative_attention_num_buckets (`int`, *optional*, defaults to 32):
The number of buckets to use for each attention layer.
+        relative_attention_max_distance (`int`, *optional*, defaults to 128):
+            The maximum relative distance used for bucket separation; relative positions at or beyond this distance all map to the same (final) bucket.
dropout_rate (`float`, *optional*, defaults to 0.1):
The ratio for all dropout layers.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
@@ -87,6 +89,7 @@ class T5Config(PretrainedConfig):
num_decoder_layers=None,
num_heads=8,
relative_attention_num_buckets=32,
+        relative_attention_max_distance=128,
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
@@ -107,6 +110,7 @@ class T5Config(PretrainedConfig):
) # default = symmetry
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
+        self.relative_attention_max_distance = relative_attention_max_distance
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
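The new option defaults to 128, the value that was previously hard-coded in the attention modules, so existing configs and checkpoints behave exactly as before. A minimal sketch of opting into a different distance (256 here is purely illustrative):

```python
from transformers import T5Config, T5ForConditionalGeneration

# relative_attention_max_distance defaults to 128 (the previously
# hard-coded value), so omitting it reproduces the old behavior.
config = T5Config(
    relative_attention_num_buckets=32,
    relative_attention_max_distance=256,  # illustrative, not a recommendation
)
model = T5ForConditionalGeneration(config)  # randomly initialized model
```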
src/transformers/models/t5/modeling_flax_t5.py
@@ -187,6 +187,7 @@ class FlaxT5Attention(nn.Module):
def setup(self):
self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
+        self.relative_attention_max_distance = self.config.relative_attention_max_distance
self.d_model = self.config.d_model
self.key_value_proj_dim = self.config.d_kv
self.n_heads = self.config.num_heads
@@ -275,6 +276,7 @@ class FlaxT5Attention(nn.Module):
relative_position,
bidirectional=(not self.causal),
num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket)
src/transformers/models/t5/modeling_t5.py
@@ -334,8 +334,8 @@ class T5Attention(nn.Module):
super().__init__()
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
@@ -430,6 +430,7 @@ class T5Attention(nn.Module):
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
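For context on what the new `max_distance` argument controls: T5 maps each query-key offset to one of `num_buckets` buckets, half of them exact and half log-spaced up to `max_distance`, beyond which all offsets share the last bucket. A self-contained NumPy paraphrase of `_relative_position_bucket` (a sketch that mirrors, but is not, the upstream code):

```python
import numpy as np

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """Sketch of T5's _relative_position_bucket in NumPy (not the upstream code)."""
    relative_position = np.asarray(relative_position, dtype=np.int64)
    relative_buckets = np.zeros_like(relative_position)
    if bidirectional:
        # Half the buckets are reserved for positive (rightward) offsets.
        num_buckets //= 2
        relative_buckets += (relative_position > 0).astype(np.int64) * num_buckets
        relative_position = np.abs(relative_position)
    else:
        # Causal case: only distances to past tokens are kept.
        relative_position = -np.minimum(relative_position, 0)
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # Offsets in [max_exact, max_distance) get log-spaced buckets; anything at
    # or beyond max_distance is clipped into the last bucket. np.maximum guards
    # log(0); those entries are discarded by np.where below.
    relative_position_if_large = max_exact + (
        np.log(np.maximum(relative_position, 1) / max_exact)
        / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(np.int64)
    relative_position_if_large = np.minimum(relative_position_if_large, num_buckets - 1)
    relative_buckets += np.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets

# Example: with num_buckets fixed, a larger max_distance stretches the
# log-spaced region over a wider range of offsets.
offsets = np.arange(-8, 9)
print(relative_position_bucket(offsets, num_buckets=8, max_distance=16))
print(relative_position_bucket(offsets, num_buckets=8, max_distance=128))
```

With `num_buckets` fixed, raising `max_distance` keeps distant offsets distinguishable further out, at the cost of coarser log-spaced buckets in between.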
src/transformers/models/t5/modeling_tf_t5.py
@@ -179,6 +179,7 @@ class TFT5Attention(tf.keras.layers.Layer):
self.output_attentions = config.output_attentions
self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
@@ -285,6 +286,7 @@ class TFT5Attention(tf.keras.layers.Layer):
relative_position,
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
)
values = tf.gather(
self.relative_attention_bias, relative_position_bucket
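Finally, a quick sanity check (assuming a standard `transformers` install) that a non-default value propagates from the config into the attention layers; the attribute path follows the PyTorch T5 module layout:

```python
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(relative_attention_max_distance=64)  # illustrative value
model = T5ForConditionalGeneration(config)

# T5Stack -> T5Block -> T5LayerSelfAttention -> T5Attention
attn = model.encoder.block[0].layer[0].SelfAttention
assert attn.relative_attention_max_distance == 64
```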