ModelZoo / ResNet50_tensorflow · Commit b2e422b0

Authored Jun 03, 2020 by Maxim Neumann; committed by A. Unique TensorFlower on Jun 03, 2020

Adjust run_classifier to support fine-tuning regression tasks.

PiperOrigin-RevId: 314607393

Parent: 4bb13e61
Showing 2 changed files with 39 additions and 14 deletions (+39 −14)
official/nlp/bert/input_pipeline.py   +3 −2
official/nlp/bert/run_classifier.py   +36 −12
official/nlp/bert/input_pipeline.py
@@ -154,13 +154,14 @@ def create_classifier_dataset(file_path,
                               seq_length,
                               batch_size,
                               is_training=True,
-                              input_pipeline_context=None):
+                              input_pipeline_context=None,
+                              label_type=tf.int64):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
       'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
       'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
-      'label_ids': tf.io.FixedLenFeature([], tf.int64),
+      'label_ids': tf.io.FixedLenFeature([], label_type),
   }
   dataset = single_file_dataset(file_path, name_to_features)
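To see what the new label_type parameter changes in practice, here is a minimal self-contained sketch (not part of the commit) that applies the same name_to_features spec with label_type=tf.float32, so 'label_ids' is parsed as a scalar float instead of an int64. The feature values are made up.

import tensorflow as tf

seq_length = 4
label_type = tf.float32  # the new parameter; the function's default stays tf.int64

# A made-up serialized record with a float regression target.
example = tf.train.Example(features=tf.train.Features(feature={
    'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=[101, 2023, 102, 0])),
    'input_mask': tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 1, 1, 0])),
    'segment_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=[0, 0, 0, 0])),
    'label_ids': tf.train.Feature(float_list=tf.train.FloatList(value=[3.8])),
}))

# Same feature spec as create_classifier_dataset builds above.
name_to_features = {
    'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
    'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
    'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
    'label_ids': tf.io.FixedLenFeature([], label_type),
}

parsed = tf.io.parse_single_example(example.SerializeToString(), name_to_features)
print(parsed['label_ids'])  # tf.Tensor(3.8, shape=(), dtype=float32)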
official/nlp/bert/run_classifier.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""BERT classification finetuning runner in TF 2.x."""
+"""BERT classification or regression finetuning runner in TF 2.x."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import json
 import math
 import os
@@ -60,6 +61,8 @@ common_flags.define_common_bert_flags()
 
 FLAGS = flags.FLAGS
 
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
 
 def get_loss_fn(num_classes):
   """Gets the classification loss function."""
@@ -77,8 +80,20 @@ def get_loss_fn(num_classes):
   return classification_loss_fn
 
 
+def get_regression_loss_fn():
+  """Gets the regression loss function."""
+
+  def regression_loss_fn(labels, logits):
+    """Regression loss."""
+    labels = tf.cast(labels, dtype=tf.float32)
+    per_example_loss = tf.math.squared_difference(labels, logits)
+    return tf.reduce_mean(per_example_loss)
+
+  return regression_loss_fn
+
+
 def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
-                   is_training):
+                   is_training, label_type=tf.int64):
   """Gets a closure to create a dataset."""
 
   def _dataset_fn(ctx=None):
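As a quick standalone check of the regression loss added here (numbers invented): labels are cast to float32 and the loss is the mean of the element-wise squared differences between labels and logits, i.e. batch MSE.

import tensorflow as tf

def regression_loss_fn(labels, logits):
  # Mirrors the inner function above: cast labels, square the error, average.
  labels = tf.cast(labels, dtype=tf.float32)
  per_example_loss = tf.math.squared_difference(labels, logits)
  return tf.reduce_mean(per_example_loss)

labels = tf.constant([3.8, 1.0])   # e.g. similarity scores for a regression task
logits = tf.constant([3.3, 1.5])   # single-logit model outputs
print(regression_loss_fn(labels, logits).numpy())  # 0.25 = mean(0.5**2, 0.5**2)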
@@ -90,7 +105,8 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
         max_seq_length,
         batch_size,
         is_training=is_training,
-        input_pipeline_context=ctx)
+        input_pipeline_context=ctx,
+        label_type=label_type)
     return dataset
 
   return _dataset_fn
@@ -113,7 +129,8 @@ def run_bert_classifier(strategy,
                         custom_callbacks=None):
   """Run BERT classifier training using low-level API."""
   max_seq_length = input_meta_data['max_seq_length']
-  num_classes = input_meta_data['num_labels']
+  num_classes = input_meta_data.get('num_labels', 1)
+  is_regression = num_classes == 1
 
   def _get_classifier_model():
     """Gets a classifier model."""
@@ -134,13 +151,17 @@ def run_bert_classifier(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model
 
-  loss_fn = get_loss_fn(num_classes)
+  loss_fn = (get_regression_loss_fn()
+             if is_regression else get_loss_fn(num_classes))
 
   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
-  def metric_fn():
-    return tf.keras.metrics.SparseCategoricalAccuracy(
-        'accuracy', dtype=tf.float32)
+  if is_regression:
+    metric_fn = functools.partial(
+        tf.keras.metrics.MeanSquaredError, 'mean_squared_error',
+        dtype=tf.float32)
+  else:
+    metric_fn = functools.partial(
+        tf.keras.metrics.SparseCategoricalAccuracy, 'accuracy',
+        dtype=tf.float32)
 
   # Start training using Keras compile/fit API.
   logging.info('Training using TF 2.x Keras compile/fit API with '
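Both branches now assign a zero-argument metric factory, matching the comment above: the training utilities call metric_fn() later, so the metric object is created on the right device and inside the strategy scope. A minimal sketch of the regression branch (sample values invented):

import functools
import tensorflow as tf

metric_fn = functools.partial(
    tf.keras.metrics.MeanSquaredError, 'mean_squared_error', dtype=tf.float32)

metric = metric_fn()                         # constructed later, at training time
metric.update_state([3.8, 1.0], [3.3, 1.5])  # (y_true, y_pred), made-up values
print(metric.name, float(metric.result()))   # mean_squared_error 0.25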
@@ -310,7 +331,7 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
   # Export uses float32 for now, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   classifier_model = bert_models.classifier_model(
-      bert_config, input_meta_data['num_labels'])[0]
+      bert_config, input_meta_data.get('num_labels', 1))[0]
   model_saving_utils.export_bert_model(
       model_export_path, model=classifier_model, checkpoint_dir=model_dir)
@@ -371,7 +392,7 @@ def run_bert(strategy,
 
 
 def custom_main(custom_callbacks=None):
-  """Run classification.
+  """Run classification or regression.
 
   Args:
     custom_callbacks: list of tf.keras.Callbacks passed to training loop.
@@ -380,6 +401,7 @@ def custom_main(custom_callbacks=None):
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
+  label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
 
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
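A short sketch (not part of the commit) of how label_type is resolved from the metadata file given via --input_meta_data_path: a regression task would declare 'label_type': 'float', while 'int' remains the default for classification. The JSON content here is made up.

import json
import tensorflow as tf

LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}

meta_json = '{"max_seq_length": 128, "num_labels": 1, "label_type": "float"}'
input_meta_data = json.loads(meta_json)
label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
print(label_type)  # <dtype: 'float32'>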
@@ -399,7 +421,8 @@ def custom_main(custom_callbacks=None):
       FLAGS.eval_data_path,
       input_meta_data['max_seq_length'],
       FLAGS.eval_batch_size,
-      is_training=False)
+      is_training=False,
+      label_type=label_type)
 
   if FLAGS.mode == 'predict':
     with strategy.scope():
@@ -432,7 +455,8 @@ def custom_main(custom_callbacks=None):
       FLAGS.train_data_path,
       input_meta_data['max_seq_length'],
       FLAGS.train_batch_size,
-      is_training=True)
+      is_training=True,
+      label_type=label_type)
   run_bert(
       strategy,
       input_meta_data,