Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
db07a725
"vscode:/vscode.git/clone" did not exist on "ca9cf75f8e02aca91f995db34b03d8f229a85fe3"
Commit
db07a725
authored
Jul 23, 2021
by
Tianqi Liu
Committed by
A. Unique TensorFlower
Jul 23, 2021
Browse files
Internal change
PiperOrigin-RevId: 386580562
parent
688a92a8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
3 deletions
+10
-3
official/nlp/modeling/models/bert_classifier.py
official/nlp/modeling/models/bert_classifier.py
+6
-2
official/nlp/modeling/models/xlnet.py
official/nlp/modeling/models/xlnet.py
+4
-1
No files found.
official/nlp/modeling/models/bert_classifier.py
View file @
db07a725
...
...
@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder.
head_name: Name of the classification head.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
'use_encoder_pooler'
, 'head_name'
) will be ignored.
"""
def
__init__
(
self
,
...
...
@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
initializer
=
'glorot_uniform'
,
dropout_rate
=
0.1
,
use_encoder_pooler
=
True
,
head_name
=
'sentence_prediction'
,
cls_head
=
None
,
**
kwargs
):
self
.
num_classes
=
num_classes
self
.
head_name
=
head_name
self
.
initializer
=
initializer
self
.
use_encoder_pooler
=
use_encoder_pooler
...
...
@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
num_classes
=
num_classes
,
initializer
=
initializer
,
dropout_rate
=
dropout_rate
,
name
=
'sentence_prediction'
)
name
=
head_name
)
predictions
=
classifier
(
cls_inputs
)
...
...
@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
return
{
'network'
:
self
.
_network
,
'num_classes'
:
self
.
num_classes
,
'head_name'
:
self
.
head_name
,
'initializer'
:
self
.
initializer
,
'use_encoder_pooler'
:
self
.
use_encoder_pooler
,
'cls_head'
:
self
.
_cls_head
,
...
...
official/nlp/modeling/models/xlnet.py
View file @
db07a725
...
...
@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model):
Defaults to a RandomNormal initializer.
summary_type: Method used to summarize a sequence into a compact vector.
dropout_rate: The dropout probability of the cls head.
head_name: Name of the classification head.
"""
def
__init__
(
...
...
@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model):
initializer
:
tf
.
keras
.
initializers
.
Initializer
=
'random_normal'
,
summary_type
:
str
=
'last'
,
dropout_rate
:
float
=
0.1
,
head_name
:
str
=
'sentence_prediction'
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_network
=
network
...
...
@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model):
'num_classes'
:
num_classes
,
'summary_type'
:
summary_type
,
'dropout_rate'
:
dropout_rate
,
'head_name'
:
head_name
,
}
if
summary_type
==
'last'
:
...
...
@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model):
initializer
=
initializer
,
dropout_rate
=
dropout_rate
,
cls_token_idx
=
cls_token_idx
,
name
=
'sentence_prediction'
)
name
=
head_name
)
def
call
(
self
,
inputs
:
Mapping
[
str
,
Any
]):
input_ids
=
inputs
[
'input_word_ids'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment