Commit d872cee2 authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 408970557
parent 892dac23
@@ -170,6 +170,31 @@ class KernelEncoderConfig(hyperparams.Config):
  scale: Optional[float] = None
@dataclasses.dataclass
class ReuseEncoderConfig(hyperparams.Config):
  """Reuse encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False
  # Reuse transformer
  reuse_attention: int = -1
  use_relative_pe: bool = False
  pe_max_seq_length: int = 512
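For reference, a minimal sketch (not part of this commit) of how the new config can be overridden, assuming the standard hyperparams.Config constructor semantics used throughout the Model Garden; the field values below are arbitrary:

# Illustrative only: overriding ReuseEncoderConfig defaults.
from official.nlp.configs import encoders

reuse_cfg = encoders.ReuseEncoderConfig(
    num_layers=6,          # shallower than the 12-layer default
    norm_first=True,       # Pre-LN Transformer
    reuse_attention=-1,    # passed through to layers.ReuseTransformer
    use_relative_pe=True)  # enable relative position encodings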
@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
  """XLNet encoder configuration."""
@@ -205,6 +230,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
  bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
  kernel: KernelEncoderConfig = KernelEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
  reuse: ReuseEncoderConfig = ReuseEncoderConfig()
  teams: BertEncoderConfig = BertEncoderConfig()
  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
@@ -472,6 +498,42 @@ def build_encoder(config: EncoderConfig,
        dict_outputs=True)
    return networks.EncoderScaffold(**kwargs)
  if encoder_type == "reuse":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        inner_dim=encoder_cfg.intermediate_size,
        inner_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        output_dropout=encoder_cfg.dropout_rate,
        attention_dropout=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        reuse_attention=encoder_cfg.reuse_attention,
        use_relative_pe=encoder_cfg.use_relative_pe,
        pe_max_seq_length=encoder_cfg.pe_max_seq_length)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.ReuseTransformer,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        feed_layer_idx=True,
        recursive=True)
    return networks.EncoderScaffold(**kwargs)
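For context, a hedged end-to-end sketch of how this new branch would be exercised. The EncoderConfig and build_encoder entry points appear in this diff, but the call below is illustrative, not part of the commit:

import tensorflow as tf
from official.nlp.configs import encoders

# Select the new encoder via the OneOfConfig `type` field.
config = encoders.EncoderConfig(
    type="reuse",
    reuse=encoders.ReuseEncoderConfig(num_layers=6))
encoder = encoders.build_encoder(config)

# The scaffold is built with dict_outputs=True, so it returns a dict.
outputs = encoder({
    "input_word_ids": tf.ones((2, 16), dtype=tf.int32),
    "input_mask": tf.ones((2, 16), dtype=tf.int32),
    "input_type_ids": tf.zeros((2, 16), dtype=tf.int32),
})
print(outputs["sequence_output"].shape)  # (2, 16, 768)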
  bert_encoder_cls = networks.BertEncoder
  if encoder_type == "bert_v2":
    bert_encoder_cls = networks.BertEncoderV2
@@ -102,6 +102,9 @@ class EncoderScaffold(tf.keras.Model):
    dict_outputs: Whether to use a dictionary as the model outputs.
    layer_idx_as_attention_seed: Whether to include layer_idx in
      attention_cfg in hidden_cfg.
    feed_layer_idx: Whether the scaffold should feed the layer index to
      hidden_cls, as `layer_idx` in hidden_cfg.
    recursive: Whether to feed the second return value of each hidden layer
      back in as the last element of the next layer's inputs; None is passed
      to the first layer as the initial state.
""" """
def __init__(self, def __init__(self,
@@ -120,6 +123,8 @@ class EncoderScaffold(tf.keras.Model):
               return_all_layer_outputs=False,
               dict_outputs=False,
               layer_idx_as_attention_seed=False,
               feed_layer_idx=False,
               recursive=False,
               **kwargs):
    if embedding_cls:
@@ -201,6 +206,8 @@ class EncoderScaffold(tf.keras.Model):
          'contain classes or instances with size specified by '
          'num_hidden_instances, got %d vs %d.') % self.name, len(hidden_cls),
          num_hidden_instances)
    # Consider supporting customized init states.
    recursive_states = None
    for i in range(num_hidden_instances):
      if isinstance(hidden_cls, list):
        cur_hidden_cls = hidden_cls[i]
@@ -211,10 +218,15 @@ class EncoderScaffold(tf.keras.Model):
                      layer_idx_as_attention_seed):
          hidden_cfg = copy.deepcopy(hidden_cfg)
          hidden_cfg['attention_cfg']['seed'] = i
        if feed_layer_idx:
          hidden_cfg['layer_idx'] = i
        layer = cur_hidden_cls(**hidden_cfg)
      else:
        layer = cur_hidden_cls
      if recursive:
        data, recursive_states = layer(
            [data, attention_mask, recursive_states])
      else:
        data = layer([data, attention_mask])
      layer_output_data.append(data)
      hidden_layers.append(layer)
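To make the two new hooks concrete, here is a toy sketch (not from this commit) of the layer interface they assume: with feed_layer_idx=True the scaffold injects `layer_idx` into hidden_cfg, and with recursive=True each layer must return a `(data, state)` pair whose state is threaded into the next layer's inputs. layers.ReuseTransformer is the real class wired up in the encoders change above; this stand-in only illustrates the contract:

import tensorflow as tf

class ToyRecursiveLayer(tf.keras.layers.Layer):
  """Illustrative stand-in matching the scaffold's recursive layer contract."""

  def __init__(self, layer_idx=0, **kwargs):
    super().__init__(**kwargs)
    # Injected by EncoderScaffold when feed_layer_idx=True.
    self.layer_idx = layer_idx

  def call(self, inputs):
    data, attention_mask, state = inputs  # state is None at the first layer
    del attention_mask  # a real layer would apply the mask
    new_state = data if state is None else state + data
    # The second return value becomes `recursive_states` for the next layer.
    return data, new_state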