Commit d872cee2 authored by Frederick Liu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 408970557
parent 892dac23
......@@ -170,6 +170,31 @@ class KernelEncoderConfig(hyperparams.Config):
scale: Optional[float] = None
@dataclasses.dataclass
class ReuseEncoderConfig(hyperparams.Config):
  """Reuse encoder configuration.

  Hyperparameters for a BERT-style encoder whose transformer stack is built
  from `layers.ReuseTransformer` blocks; `build_encoder` maps these fields
  onto `networks.EncoderScaffold` when `encoder_type == "reuse"`.
  """
  # Size of the token vocabulary for the word embedding table.
  vocab_size: int = 30522
  # Width of the transformer hidden layers (also the pooled output dimension).
  hidden_size: int = 768
  # Number of stacked ReuseTransformer layers.
  num_layers: int = 12
  # Number of attention heads in each transformer layer.
  num_attention_heads: int = 12
  # Activation for the transformer feed-forward (inner) layers.
  hidden_activation: str = "gelu"
  # Width of the transformer feed-forward (inner) layers.
  intermediate_size: int = 3072
  # Dropout applied to the embeddings and transformer layer outputs.
  dropout_rate: float = 0.1
  # Dropout applied to the attention probabilities.
  attention_dropout_rate: float = 0.1
  # Maximum sequence length supported by the position embedding table.
  max_position_embeddings: int = 512
  # Number of token-type (segment) ids.
  type_vocab_size: int = 2
  # Stddev of the TruncatedNormal initializer used for all weights.
  initializer_range: float = 0.02
  # Optional factorized embedding width; presumably None means "use
  # hidden_size" (not consumed by the visible "reuse" build path) —
  # TODO confirm.
  embedding_size: Optional[int] = None
  # Optional restriction of the final layer's output sequence length (not
  # consumed by the visible "reuse" build path) — TODO confirm.
  output_range: Optional[int] = None
  # Whether to return outputs from every encoder layer. NOTE(review): the
  # visible "reuse" build path hardcodes return_all_layer_outputs=False —
  # confirm this flag is honored elsewhere.
  return_all_encoder_outputs: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False
  # Reuse transformer
  # Forwarded to ReuseTransformer as `reuse_attention`; presumably -1 means
  # no attention reuse — TODO confirm against layers.ReuseTransformer.
  reuse_attention: int = -1
  # Whether to use relative position encodings (forwarded to the layer cfg).
  use_relative_pe: bool = False
  # Maximum sequence length for the relative position encoding table.
  pe_max_seq_length: int = 512
@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
"""XLNet encoder configuration."""
......@@ -205,6 +230,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
kernel: KernelEncoderConfig = KernelEncoderConfig()
mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
reuse: ReuseEncoderConfig = ReuseEncoderConfig()
teams: BertEncoderConfig = BertEncoderConfig()
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
......@@ -472,6 +498,42 @@ def build_encoder(config: EncoderConfig,
dict_outputs=True)
return networks.EncoderScaffold(**kwargs)
if encoder_type == "reuse":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
inner_dim=encoder_cfg.intermediate_size,
inner_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
output_dropout=encoder_cfg.dropout_rate,
attention_dropout=encoder_cfg.attention_dropout_rate,
norm_first=encoder_cfg.norm_first,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
reuse_attention=encoder_cfg.reuse_attention,
use_relative_pe=encoder_cfg.use_relative_pe,
pe_max_seq_length=encoder_cfg.pe_max_seq_length)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cls=layers.ReuseTransformer,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=False,
dict_outputs=True,
feed_layer_idx=True,
recursive=True)
return networks.EncoderScaffold(**kwargs)
bert_encoder_cls = networks.BertEncoder
if encoder_type == "bert_v2":
bert_encoder_cls = networks.BertEncoderV2
......
......@@ -102,6 +102,9 @@ class EncoderScaffold(tf.keras.Model):
dict_outputs: Whether to use a dictionary as the model outputs.
layer_idx_as_attention_seed: Whether to include layer_idx in
attention_cfg in hidden_cfg.
feed_layer_idx: Whether the scaffold should feed the layer index to hidden_cls.
recursive: Whether to pass the second return value of the hidden layer as the
last element among the layer's inputs. `None` is passed as the initial state.
"""
def __init__(self,
......@@ -120,6 +123,8 @@ class EncoderScaffold(tf.keras.Model):
return_all_layer_outputs=False,
dict_outputs=False,
layer_idx_as_attention_seed=False,
feed_layer_idx=False,
recursive=False,
**kwargs):
if embedding_cls:
......@@ -201,6 +206,8 @@ class EncoderScaffold(tf.keras.Model):
'contain classes or instances with size specified by '
'num_hidden_instances, got %d vs %d.') % self.name, len(hidden_cls),
num_hidden_instances)
# Consider supporting customized init states.
recursive_states = None
for i in range(num_hidden_instances):
if isinstance(hidden_cls, list):
cur_hidden_cls = hidden_cls[i]
......@@ -211,10 +218,15 @@ class EncoderScaffold(tf.keras.Model):
layer_idx_as_attention_seed):
hidden_cfg = copy.deepcopy(hidden_cfg)
hidden_cfg['attention_cfg']['seed'] = i
if feed_layer_idx:
hidden_cfg['layer_idx'] = i
layer = cur_hidden_cls(**hidden_cfg)
else:
layer = cur_hidden_cls
data = layer([data, attention_mask])
if recursive:
data, recursive_states = layer([data, attention_mask, recursive_states])
else:
data = layer([data, attention_mask])
layer_output_data.append(data)
hidden_layers.append(layer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment