Commit b12db2db authored by Jiayu Ye, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 436647758
parent 3e53adfe
@@ -1004,6 +1004,7 @@ class T5TransformerParams:
   num_heads: int
   d_ff: int
   vocab_size: int
+  target_vocab_size: Optional[int] = None
   dropout_rate: float = 0.0
   layer_norm_epsilon: float = 1e-6
   shared_embedding: bool = False
@@ -1159,11 +1160,15 @@ class Decoder(Module):
     self.compute_dtype = compute_dtype
     if self.config.num_decoder_layers is None:
       self.config.num_decoder_layers = self.config.num_layers
+    if not hasattr(
+        self.config,
+        "target_vocab_size") or self.config.target_vocab_size is None:
+      self.config.target_vocab_size = self.config.vocab_size
     with self.name_scope:
       # Target Embedding.
       if shared_embedding is None:
         self.target_embed = Embed(
-            vocab_size=self.config.vocab_size,
+            vocab_size=self.config.target_vocab_size,
             features=self.config.d_model,
             embeddings_initializer=self.config.vocab_embeddings_initializer,
             dtype=self.dtype,
@@ -1211,7 +1216,7 @@ class Decoder(Module):
     if not self.config.logits_via_embedding:
       self.logits_dense = Linear(
           in_features=self.config.d_model,
-          out_features=self.config.vocab_size,
+          out_features=self.config.target_vocab_size,
           use_bias=False,
           dtype=self.dtype,
           name="logits")
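The change adds an optional target_vocab_size to T5TransformerParams: the decoder's target embedding and logits layer now size themselves from target_vocab_size when it is set, and fall back to vocab_size otherwise. Below is a minimal, self-contained Python sketch of that fallback logic; the trimmed-down dataclass and the resolve_target_vocab_size helper are illustrative assumptions for this note, not part of the Model Garden library.

import dataclasses
from typing import Optional


@dataclasses.dataclass
class T5TransformerParams:
  """Trimmed-down stand-in for the real config; only relevant fields kept."""
  num_layers: int
  d_model: int
  num_heads: int
  d_ff: int
  vocab_size: int
  target_vocab_size: Optional[int] = None


def resolve_target_vocab_size(config) -> int:
  """Mirrors the decoder fallback: use vocab_size when target_vocab_size is unset."""
  if getattr(config, "target_vocab_size", None) is None:
    return config.vocab_size
  return config.target_vocab_size


# Shared vocabulary: decoder embedding and logits reuse vocab_size.
shared = T5TransformerParams(
    num_layers=2, d_model=64, num_heads=4, d_ff=128, vocab_size=32000)
assert resolve_target_vocab_size(shared) == 32000

# Distinct target vocabulary, e.g. when source and target tokenizers differ.
split = T5TransformerParams(
    num_layers=2, d_model=64, num_heads=4, d_ff=128,
    vocab_size=32000, target_vocab_size=8000)
assert resolve_target_vocab_size(split) == 8000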