"docs/vscode:/vscode.git/clone" did not exist on "0604f4fc000113004223cde8f2db0c850e8a1f2e"
Commit 4912beaa authored by Jiayu Ye, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 476990304
parent d4f9d872
@@ -1023,6 +1023,9 @@ class T5TransformerParams:
   num_decoder_layers: Optional[int] = None
   one_hot_embedding: bool = True
   layer_sharing: bool = False
+  # If true, uses one relative embedding for all encoder layers and one for all
+  # decoder layers. Otherwise, have relative embedding for each layer.
+  use_shared_relative_position_bias: bool = True
 
 
 class Encoder(Module):
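As an aside, a minimal usage sketch of the new flag (the import path and the tiny hyperparameters are assumptions taken from the unit tests further down, not prescribed by this change):

import tensorflow as tf
from official.nlp.modeling.models import t5  # assumed Model Garden import path

# Build a small T5 with one relative position embedding per layer
# (use_shared_relative_position_bias=False) instead of a single shared
# embedding per stack.
config = t5.T5TransformerParams(
    num_layers=1,
    num_decoder_layers=2,
    d_model=8,
    d_kv=4,
    num_heads=4,
    d_ff=32,
    vocab_size=10,
    shared_embedding=True,
    ffn_activations=("relu",),
    logits_via_embedding=True,
    use_shared_relative_position_bias=False)
transformer = t5.T5Transformer(config, compute_dtype=tf.float32)

# The per-layer embeddings appear as separately named trainable variables
# (relative_posemb_0, relative_posemb_1, ...).
for v in transformer.trainable_variables:
  if "relative_posemb" in v.name:
    print(v.name, v.shape)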
@@ -1051,17 +1054,34 @@ class Encoder(Module):
       self.input_embed = shared_embedding
       # Creates an alias to the input embed for encoder-only models.
       self.word_embed = self.input_embed
-    self.relative_embedding = RelativePositionEmbedding(
-        num_heads=self.config.num_heads,
-        relative_attention_num_buckets=self.config
-        .relative_attention_num_buckets,
-        relative_attention_max_distance=self.config
-        .relative_attention_max_distance,
-        bidirectional=self.config.bidirectional,
-        embeddings_initializer=self.config.relative_embeddings_initializer,
-        dtype=self.dtype,
-        compute_dtype=self.compute_dtype,
-        name="relative_posemb")
+    if config.use_shared_relative_position_bias:
+      self.relative_embedding = RelativePositionEmbedding(
+          num_heads=self.config.num_heads,
+          relative_attention_num_buckets=self.config
+          .relative_attention_num_buckets,
+          relative_attention_max_distance=self.config
+          .relative_attention_max_distance,
+          bidirectional=self.config.bidirectional,
+          embeddings_initializer=self.config.relative_embeddings_initializer,
+          dtype=self.dtype,
+          compute_dtype=self.compute_dtype,
+          name="relative_posemb")
+    else:
+      self.relative_embeddings = []
+      for layer_idx in range(self.config.num_layers):
+        relative_embedding = RelativePositionEmbedding(
+            num_heads=self.config.num_heads,
+            relative_attention_num_buckets=self.config
+            .relative_attention_num_buckets,
+            relative_attention_max_distance=self.config
+            .relative_attention_max_distance,
+            bidirectional=self.config.bidirectional,
+            embeddings_initializer=self.config
+            .relative_embeddings_initializer,
+            dtype=self.dtype,
+            compute_dtype=self.compute_dtype,
+            name=f"relative_posemb_{layer_idx}")
+        self.relative_embeddings.append(relative_embedding)
     self.input_dropout = Dropout(self.config.dropout_rate,)
     self.encoder_layers = []
     for layer_idx in range(self.config.num_layers):
@@ -1088,6 +1108,26 @@ class Encoder(Module):
         name="final_layer_norm")
     self.output_dropout = Dropout(self.config.dropout_rate,)
 
+  @tf.Module.with_name_scope
+  def get_relpos_bias(self,
+                      input_length: int,
+                      dense_inputs: tf.Tensor,
+                      layer_idx: Optional[int] = None) -> tf.Tensor:
+    if self.config.use_shared_relative_position_bias:
+      position_bias = self.relative_embedding(input_length, input_length)
+    else:
+      position_bias = self.relative_embeddings[layer_idx](input_length,
+                                                          input_length)
+    if dense_inputs is not None:
+      # Here we ignore relative position bias for dense embeddings.
+      # TODO(yejiayu): If we proceed to video use cases, rework this part.
+      dense_input_length = tf_utils.get_shape_list(dense_inputs)[1]
+      # Position bias shape: [batch, 1, len, len]
+      paddings = tf.constant([[0, 0], [0, 0], [0, dense_input_length],
+                              [0, dense_input_length]])
+      position_bias = tf.pad(position_bias, paddings, "CONSTANT")
+    return position_bias
+
   @tf.Module.with_name_scope
   def __call__(self,
                inputs=None,
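The zero padding in Encoder.get_relpos_bias simply widens the bias so dense (pre-embedded) inputs attend with no relative position bias. A standalone sketch of the shape arithmetic, with toy sizes chosen only for illustration:

import tensorflow as tf

# A bias of shape [batch, 1, len, len] grows to
# [batch, 1, len + dense_len, len + dense_len]; the new rows and columns
# are zeros, so dense positions receive no relative position bias.
batch, length, dense_length = 2, 6, 2
position_bias = tf.random.normal([batch, 1, length, length])
paddings = tf.constant([[0, 0], [0, 0], [0, dense_length], [0, dense_length]])
padded = tf.pad(position_bias, paddings, "CONSTANT")
print(padded.shape)  # (2, 1, 8, 8)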
@@ -1127,17 +1167,9 @@ class Encoder(Module):
       input_length = tf_utils.get_shape_list(inputs)[1]
     else:
       input_length = 0
-    position_bias = self.relative_embedding(input_length, input_length)
-    if dense_inputs is not None:
-      # Here we ignore relative position bias for dense embeddings.
-      # TODO(yejiayu): If we proceed to video use cases, rework this part.
-      dense_input_length = tf_utils.get_shape_list(dense_inputs)[1]
-      # Position bias shape: [batch, 1, len, len]
-      paddings = tf.constant([[0, 0], [0, 0], [0, dense_input_length],
-                              [0, dense_input_length]])
-      position_bias = tf.pad(position_bias, paddings, "CONSTANT")
 
     for i in range(cfg.num_layers):
+      position_bias = self.get_relpos_bias(input_length, dense_inputs, i)
       x = self.encoder_layers[i](
           x,
           attention_mask=encoder_mask,
@@ -1180,17 +1212,34 @@ class Decoder(Module):
       self.target_embed = shared_embedding
     self.target_dropout = Dropout(self.config.dropout_rate,)
     # Position bias for the target self attention.
-    self.relative_embedding = RelativePositionEmbedding(
-        num_heads=self.config.num_heads,
-        relative_attention_num_buckets=self.config
-        .relative_attention_num_buckets,
-        relative_attention_max_distance=self.config
-        .relative_attention_max_distance,
-        bidirectional=self.config.bidirectional,
-        embeddings_initializer=self.config.relative_embeddings_initializer,
-        dtype=self.dtype,
-        compute_dtype=self.compute_dtype,
-        name="relative_posemb")
+    if config.use_shared_relative_position_bias:
+      self.relative_embedding = RelativePositionEmbedding(
+          num_heads=self.config.num_heads,
+          relative_attention_num_buckets=self.config
+          .relative_attention_num_buckets,
+          relative_attention_max_distance=self.config
+          .relative_attention_max_distance,
+          bidirectional=self.config.bidirectional,
+          embeddings_initializer=self.config.relative_embeddings_initializer,
+          dtype=self.dtype,
+          compute_dtype=self.compute_dtype,
+          name="relative_posemb")
+    else:
+      self.relative_embeddings = []
+      for layer_idx in range(self.config.num_decoder_layers):
+        relative_embedding = RelativePositionEmbedding(
+            num_heads=self.config.num_heads,
+            relative_attention_num_buckets=self.config
+            .relative_attention_num_buckets,
+            relative_attention_max_distance=self.config
+            .relative_attention_max_distance,
+            bidirectional=self.config.bidirectional,
+            embeddings_initializer=self.config
+            .relative_embeddings_initializer,
+            dtype=self.dtype,
+            compute_dtype=self.compute_dtype,
+            name=f"relative_posemb_{layer_idx}")
+        self.relative_embeddings.append(relative_embedding)
     self.decoder_layers = []
     for layer_idx in range(self.config.num_decoder_layers):
       if self.config.layer_sharing and layer_idx > 0:
@@ -1223,6 +1272,13 @@ class Decoder(Module):
           dtype=self.dtype,
           name="logits")
 
+  @tf.Module.with_name_scope
+  def get_relpos_bias(self, input_length: int, layer_idx: int) -> tf.Tensor:
+    if self.config.use_shared_relative_position_bias:
+      return self.relative_embedding(input_length, input_length)
+    else:
+      return self.relative_embeddings[layer_idx](input_length, input_length)
+
   @tf.Module.with_name_scope
   def __call__(self,
                decoder_input_tokens,
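Encoder and Decoder now use the same shared-versus-layerwise dispatch. A self-contained sketch of that pattern, with plain variables standing in for RelativePositionEmbedding (names and shapes here are illustrative only, not the library's API):

import tensorflow as tf

class ToyRelposBias(tf.Module):
  """Holds one shared bias table, or one table per layer."""

  def __init__(self, num_layers, num_buckets=32, num_heads=4, shared=True):
    super().__init__(name="toy_relpos")
    self.shared = shared
    count = 1 if shared else num_layers
    self.tables = [
        tf.Variable(
            tf.zeros([num_buckets, num_heads]),
            name=("relative_posemb" if shared else f"relative_posemb_{i}"))
        for i in range(count)
    ]

  def get(self, layer_idx):
    # Shared: every layer reads table 0; layerwise: per-layer lookup.
    return self.tables[0] if self.shared else self.tables[layer_idx]

print(len(ToyRelposBias(3, shared=True).trainable_variables))   # 1
print(len(ToyRelposBias(3, shared=False).trainable_variables))  # 3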
@@ -1266,12 +1322,14 @@ class Decoder(Module):
     tensor_shape = tf_utils.get_shape_list(x)
     tensor_shape[-2] = 1
     x = self.target_dropout(x, noise_shape=tensor_shape, training=training)
-    if cache is not None:
-      position_bias = self.relative_embedding(max_decode_len, max_decode_len)
-    else:
-      input_length = tf_utils.get_shape_list(decoder_input_tokens)[1]
-      position_bias = self.relative_embedding(input_length, input_length)
+
     for i in range(cfg.num_decoder_layers):
+      if cache is not None:
+        position_bias = self.get_relpos_bias(max_decode_len, i)
+      else:
+        input_length = tf_utils.get_shape_list(decoder_input_tokens)[1]
+        position_bias = self.get_relpos_bias(input_length, i)
+
       if cache is None:
         x, _ = self.decoder_layers[i](
             x,
@@ -484,7 +484,7 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
       self.assertEqual(v.dtype, tf.float32)
 
   @parameterized.named_parameters(
-      ("t5_10", ("relu",), True, 26, False, tf.float32),)
+      ("t5_10_dense", ("relu",), True, 26, False, tf.float32),)
   def test_transformer_with_dense(self, ffn_activations, logits_via_embedding,
                                   expect_num_variables, layer_sharing, dtype):
     max_decode_len = 10
@@ -500,6 +500,7 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
         ffn_activations=ffn_activations,
         logits_via_embedding=logits_via_embedding)
     transformer = t5.T5Transformer(config, compute_dtype=dtype)
+
     self.assertLen(transformer.trainable_variables, expect_num_variables)
     inputs = tf.convert_to_tensor(
         np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
@@ -535,6 +536,75 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
       print(v.name, v.shape)
       self.assertEqual(v.dtype, tf.float32)
 
+  @parameterized.named_parameters(
+      ("t5_10_dense_layerwise_relpos",
+       ("relu",), True, 26, False, tf.float32, False, 1),
+      ("t5_10_dense_shared_relpos_d2",
+       ("relu",), True, 39, False, tf.float32, True, 2),
+      ("t5_10_dense_layerwise_relpos_d2",
+       ("relu",), True, 40, False, tf.float32, False, 2),
+  )
+  def test_transformer_with_lw_relpos(self, ffn_activations,
+                                      logits_via_embedding,
+                                      expect_num_variables, layer_sharing,
+                                      dtype, use_shared_relpos,
+                                      num_decoder_layers):
+    max_decode_len = 10
+    config = t5.T5TransformerParams(
+        num_layers=1,
+        num_decoder_layers=num_decoder_layers,
+        d_model=8,
+        d_kv=4,
+        num_heads=4,
+        d_ff=32,
+        vocab_size=10,
+        shared_embedding=True,
+        layer_sharing=layer_sharing,
+        ffn_activations=ffn_activations,
+        logits_via_embedding=logits_via_embedding,
+        use_shared_relative_position_bias=use_shared_relpos)
+    transformer = t5.T5Transformer(config, compute_dtype=dtype)
+
+    self.assertLen(transformer.trainable_variables, expect_num_variables)
+    inputs = tf.convert_to_tensor(
+        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
+    segments = tf.convert_to_tensor(
+        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
+    dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
+    dense_segments = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))
+    outputs = transformer(
+        encoder_input_tokens=inputs,
+        encoder_dense_inputs=dense_inputs,
+        decoder_input_tokens=inputs,
+        decoder_target_tokens=inputs,
+        encoder_segment_ids=segments,
+        encoder_dense_segment_ids=dense_segments,
+        decoder_segment_ids=segments)
+
+    cache = {}
+    batch_size = 2
+    for i in range(num_decoder_layers):
+      cache[i] = _create_cache(
+          batch_size,
+          max_decode_len,
+          config.num_heads,
+          config.d_kv,
+          dtype=dtype)
+    outputs = transformer.decode(
+        encoder_input_tokens=inputs,
+        encoder_dense_inputs=dense_inputs,
+        encoded=outputs["encoded"],
+        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
+        decode_position=1,
+        decode=True,
+        max_decode_len=max_decode_len,
+        cache=cache)
+    self.assertEqual(outputs["logits"].shape,
+                     (batch_size, 1, config.vocab_size))
+    for v in transformer.trainable_variables:
+      print(v.name, v.shape)
+      self.assertEqual(v.dtype, tf.float32)
+
   @parameterized.named_parameters(
       ("t5_10", ("relu",), True, 26, False, tf.float32),)
   def test_transformer_with_dense_only(self, ffn_activations,
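A note on the expected variable counts in the new test (an inference from the test parameters, not stated in the change): with one encoder and one decoder layer, the layerwise and shared settings each create exactly one relative embedding per stack, so the count stays at the 26 of the existing dense test; adding a second decoder layer under a shared bias brings the count to 39; switching that same two-layer configuration to layerwise bias adds exactly one more variable, the second decoder layer's own relative position embedding, giving 40.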