Commit d3b705d2, authored by Tianqi Liu and committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 386580562
parent 601953a4
......@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder.
head_name: Name of the classification head.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
'use_encoder_pooler', 'head_name') will be ignored.
"""
def __init__(self,
......@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
initializer='glorot_uniform',
dropout_rate=0.1,
use_encoder_pooler=True,
head_name='sentence_prediction',
cls_head=None,
**kwargs):
self.num_classes = num_classes
self.head_name = head_name
self.initializer = initializer
self.use_encoder_pooler = use_encoder_pooler
......@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
name='sentence_prediction')
name=head_name)
predictions = classifier(cls_inputs)
......@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
return {
'network': self._network,
'num_classes': self.num_classes,
'head_name': self.head_name,
'initializer': self.initializer,
'use_encoder_pooler': self.use_encoder_pooler,
'cls_head': self._cls_head,
......
......@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model):
Defaults to a RandomNormal initializer.
summary_type: Method used to summarize a sequence into a compact vector.
dropout_rate: The dropout probability of the cls head.
head_name: Name of the classification head.
"""
def __init__(
......@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model):
initializer: tf.keras.initializers.Initializer = 'random_normal',
summary_type: str = 'last',
dropout_rate: float = 0.1,
head_name: str = 'sentence_prediction',
**kwargs):
super().__init__(**kwargs)
self._network = network
......@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model):
'num_classes': num_classes,
'summary_type': summary_type,
'dropout_rate': dropout_rate,
'head_name': head_name,
}
if summary_type == 'last':
......@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model):
initializer=initializer,
dropout_rate=dropout_rate,
cls_token_idx=cls_token_idx,
name='sentence_prediction')
name=head_name)
def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_word_ids']
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment