Commit d3b705d2 authored by Tianqi Liu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 386580562
parent 601953a4
...@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model): ...@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head. dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder. encoder.
head_name: Name of the classification head.
cls_head: (Optional) The layer instance to use for the classifier head. cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits. It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate', If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored. 'use_encoder_pooler', 'head_name') will be ignored.
""" """
def __init__(self, def __init__(self,
...@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model): ...@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
initializer='glorot_uniform', initializer='glorot_uniform',
dropout_rate=0.1, dropout_rate=0.1,
use_encoder_pooler=True, use_encoder_pooler=True,
head_name='sentence_prediction',
cls_head=None, cls_head=None,
**kwargs): **kwargs):
self.num_classes = num_classes self.num_classes = num_classes
self.head_name = head_name
self.initializer = initializer self.initializer = initializer
self.use_encoder_pooler = use_encoder_pooler self.use_encoder_pooler = use_encoder_pooler
...@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model): ...@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
num_classes=num_classes, num_classes=num_classes,
initializer=initializer, initializer=initializer,
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
name='sentence_prediction') name=head_name)
predictions = classifier(cls_inputs) predictions = classifier(cls_inputs)
...@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model): ...@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
return { return {
'network': self._network, 'network': self._network,
'num_classes': self.num_classes, 'num_classes': self.num_classes,
'head_name': self.head_name,
'initializer': self.initializer, 'initializer': self.initializer,
'use_encoder_pooler': self.use_encoder_pooler, 'use_encoder_pooler': self.use_encoder_pooler,
'cls_head': self._cls_head, 'cls_head': self._cls_head,
......
...@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model): ...@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model):
Defaults to a RandomNormal initializer. Defaults to a RandomNormal initializer.
summary_type: Method used to summarize a sequence into a compact vector. summary_type: Method used to summarize a sequence into a compact vector.
dropout_rate: The dropout probability of the cls head. dropout_rate: The dropout probability of the cls head.
head_name: Name of the classification head.
""" """
def __init__( def __init__(
...@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model): ...@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model):
initializer: tf.keras.initializers.Initializer = 'random_normal', initializer: tf.keras.initializers.Initializer = 'random_normal',
summary_type: str = 'last', summary_type: str = 'last',
dropout_rate: float = 0.1, dropout_rate: float = 0.1,
head_name: str = 'sentence_prediction',
**kwargs): **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._network = network self._network = network
...@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model): ...@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model):
'num_classes': num_classes, 'num_classes': num_classes,
'summary_type': summary_type, 'summary_type': summary_type,
'dropout_rate': dropout_rate, 'dropout_rate': dropout_rate,
'head_name': head_name,
} }
if summary_type == 'last': if summary_type == 'last':
...@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model): ...@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model):
initializer=initializer, initializer=initializer,
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
cls_token_idx=cls_token_idx, cls_token_idx=cls_token_idx,
name='sentence_prediction') name=head_name)
def call(self, inputs: Mapping[str, Any]): def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_word_ids'] input_ids = inputs['input_word_ids']
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment