Commit b12db2db authored by Jiayu Ye, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 436647758
parent 3e53adfe
...@@ -1004,6 +1004,7 @@ class T5TransformerParams: ...@@ -1004,6 +1004,7 @@ class T5TransformerParams:
num_heads: int num_heads: int
d_ff: int d_ff: int
vocab_size: int vocab_size: int
target_vocab_size: Optional[int] = None
dropout_rate: float = 0.0 dropout_rate: float = 0.0
layer_norm_epsilon: float = 1e-6 layer_norm_epsilon: float = 1e-6
shared_embedding: bool = False shared_embedding: bool = False
...@@ -1159,11 +1160,15 @@ class Decoder(Module): ...@@ -1159,11 +1160,15 @@ class Decoder(Module):
self.compute_dtype = compute_dtype self.compute_dtype = compute_dtype
if self.config.num_decoder_layers is None: if self.config.num_decoder_layers is None:
self.config.num_decoder_layers = self.config.num_layers self.config.num_decoder_layers = self.config.num_layers
if not hasattr(
self.config,
"target_vocab_size") or self.config.target_vocab_size is None:
self.config.target_vocab_size = self.config.vocab_size
with self.name_scope: with self.name_scope:
# Target Embedding. # Target Embedding.
if shared_embedding is None: if shared_embedding is None:
self.target_embed = Embed( self.target_embed = Embed(
vocab_size=self.config.vocab_size, vocab_size=self.config.target_vocab_size,
features=self.config.d_model, features=self.config.d_model,
embeddings_initializer=self.config.vocab_embeddings_initializer, embeddings_initializer=self.config.vocab_embeddings_initializer,
dtype=self.dtype, dtype=self.dtype,
...@@ -1211,7 +1216,7 @@ class Decoder(Module): ...@@ -1211,7 +1216,7 @@ class Decoder(Module):
if not self.config.logits_via_embedding: if not self.config.logits_via_embedding:
self.logits_dense = Linear( self.logits_dense = Linear(
in_features=self.config.d_model, in_features=self.config.d_model,
out_features=self.config.vocab_size, out_features=self.config.target_vocab_size,
use_bias=False, use_bias=False,
dtype=self.dtype, dtype=self.dtype,
name="logits") name="logits")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment