ModelZoo / ResNet50_tensorflow

Commit cf80ed4e
Authored Aug 02, 2021 by anivegesana

Merge branch 'purdue-yolo' of https://github.com/tensorflow/models into detection_generator_pr_2

Parents: 394cefcc, 461b3587

Changes: 76 files in the full commit. This page shows 20 changed files with 171 additions and 110 deletions (+171, -110).
Files changed on this page:

official/core/config_definitions.py                        +2  -1
official/core/train_utils.py                               +21 -8
official/modeling/multitask/configs.py                     +2  -1
official/modeling/multitask/evaluator.py                   +22 -21
official/modeling/multitask/evaluator_test.py              +3  -8
official/modeling/multitask/interleaving_trainer.py        +11 -1
official/modeling/multitask/interleaving_trainer_test.py   +1  -0
official/modeling/multitask/multitask.py                   +3  -9
official/modeling/multitask/train_lib.py                   +13 -7
official/modeling/multitask/train_lib_test.py              +11 -10
official/nlp/continuous_finetune_lib.py                    +4  -2
official/nlp/metrics/bleu.py                               +15 -14
official/nlp/modeling/__init__.py                          +1  -0
official/nlp/modeling/layers/mobile_bert_layers.py         +19 -0
official/nlp/modeling/layers/mobile_bert_layers_test.py    +16 -0
official/nlp/modeling/layers/spectral_normalization.py     +13 -10
official/nlp/modeling/models/bert_classifier.py            +6  -2
official/nlp/modeling/models/bert_classifier_test.py       +2  -4
official/nlp/modeling/models/bert_pretrainer_test.py       +4  -8
official/nlp/modeling/models/bert_span_labeler_test.py     +2  -4
official/core/config_definitions.py

@@ -239,9 +239,10 @@ class TrainerConfig(base_config.Config):
 @dataclasses.dataclass
 class TaskConfig(base_config.Config):
   init_checkpoint: str = ""
-  model: base_config.Config = None
+  model: Optional[base_config.Config] = None
   train_data: DataConfig = DataConfig()
   validation_data: DataConfig = DataConfig()
+  name: Optional[str] = None


 @dataclasses.dataclass
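For reference, `Optional[...]` is what lets a typed dataclass field default to `None` without tripping strict type checkers, and the new `name` field gives every task config an addressable identifier (used further down by `MultiTask.from_config`). A minimal sketch of the pattern, with a hypothetical `ExampleTaskConfig` standing in for the real config class:

import dataclasses
from typing import Optional

@dataclasses.dataclass
class ExampleTaskConfig:
  init_checkpoint: str = ''
  # Annotating `model: SomeConfig = None` fails strict type checking;
  # Optional is the honest type when None is a legal default.
  model: Optional[object] = None   # stand-in for base_config.Config
  name: Optional[str] = None       # new: tasks can carry their own name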
official/core/train_utils.py

@@ -142,14 +142,19 @@ class BestCheckpointExporter:
     return self._checkpoint_manager

-  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
+  def maybe_export_checkpoint(
+      self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
     """Compare eval_logs with past eval_logs and export checkpoint if better."""
     logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
                  eval_logs, global_step)
     if self._best_ckpt_logs is None or self._new_metric_is_better(
         self._best_ckpt_logs, eval_logs):
       self._best_ckpt_logs = eval_logs
-      self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
-                                    global_step)
+      if write_logs:
+        self.export_best_eval_metric(self._best_ckpt_logs, global_step)
+      self._get_checkpoint_manager(checkpoint).save()
+      return True
+    return False

   def _maybe_load_best_eval_metric(self):
     if not tf.io.gfile.exists(self.best_ckpt_logs_path):
@@ -180,7 +185,7 @@ class BestCheckpointExporter:
       return True
     return False

-  def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
+  def export_best_eval_metric(self, eval_logs, global_step):
     """Export evaluation results of the best checkpoint into a json file."""
     eval_logs_ext = copy.copy(eval_logs)
     eval_logs_ext['best_ckpt_global_step'] = global_step
@@ -190,8 +195,6 @@ class BestCheckpointExporter:
     with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
       writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
-    self._get_checkpoint_manager(checkpoint).save()

   @property
   def best_ckpt_logs(self):
     return self._best_ckpt_logs
@@ -377,11 +380,15 @@ def remove_ckpts(model_dir):
     tf.io.gfile.remove(file_to_remove)


-def try_count_params(model: tf.keras.Model):
+def try_count_params(
+    model: Union[tf.Module, tf.keras.Model], trainable_only: bool = False):
   """Count the number of parameters if model is possible.

   Args:
     model: Try to count the number of params in this model.
+    trainable_only: Whether to calculate trainable params only. This flag is
+      not used when the model has `count_params` attribute.

   Returns:
     The number of parameters or None.
@@ -395,7 +402,13 @@ def try_count_params(model: tf.keras.Model):
         'because the model was not feed any input, e.g., the max '
         'train step already reached before this run.')
       return None
-  return None
+  else:
+    total_params = 0
+    variables = model.trainable_variables if trainable_only else model.variables
+    for var in variables:
+      shape = tf.shape(var)
+      total_params += tf.math.reduce_prod(shape).numpy()
+  return total_params


 def try_count_flops(model: Union[tf.Module, tf.keras.Model],
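The new fallback branch counts parameters by hand when the object is a bare tf.Module without a Keras `count_params` method. A standalone sketch of that counting loop (the `TinyModule` class and `count_params` helper here are illustrative, not from the diff):

import tensorflow as tf

class TinyModule(tf.Module):
  """A bare tf.Module with no Keras count_params method."""

  def __init__(self):
    self.w = tf.Variable(tf.zeros([3, 4]), trainable=True)
    self.b = tf.Variable(tf.zeros([4]), trainable=False)

def count_params(module: tf.Module, trainable_only: bool = False) -> int:
  variables = module.trainable_variables if trainable_only else module.variables
  total = 0
  for var in variables:
    # reduce_prod over the shape gives each variable's element count.
    total += int(tf.math.reduce_prod(tf.shape(var)).numpy())
  return total

m = TinyModule()
assert count_params(m) == 16                       # 3*4 + 4
assert count_params(m, trainable_only=True) == 12  # 3*4 only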
official/modeling/multitask/configs.py

@@ -23,6 +23,7 @@ from official.modeling import hyperparams
 @dataclasses.dataclass
 class TaskRoutine(hyperparams.Config):
+  # TODO(hongkuny): deprecate the task_name once we migrated client code.
   task_name: str = ""
   task_config: cfg.TaskConfig = None
   eval_steps: Optional[int] = None
@@ -76,4 +77,4 @@ class MultiEvalExperimentConfig(cfg.ExperimentConfig):
   Attributes:
     eval_tasks: individual evaluation tasks.
   """
-  eval_tasks: MultiTaskConfig = MultiTaskConfig()
+  eval_tasks: Tuple[TaskRoutine, ...] = ()
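With `eval_tasks` now a plain tuple of `TaskRoutine`s rather than a nested `MultiTaskConfig`, an experiment config lists its evaluation tasks directly. A sketch, using the `FooConfig`/`BarConfig` test task configs that appear in `train_lib_test.py` below:

experiment_config = configs.MultiEvalExperimentConfig(
    task=test_utils.FooConfig(),  # the single training task
    eval_tasks=(
        configs.TaskRoutine(
            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
        configs.TaskRoutine(
            task_name='bar', task_config=test_utils.BarConfig(), eval_steps=3),
    ))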
official/modeling/multitask/evaluator.py

@@ -16,14 +16,14 @@
 The evaluator implements the Orbit `AbstractEvaluator` interface.
 """
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union

 import gin
 import orbit
 import tensorflow as tf
+from official.core import base_task
 from official.core import train_utils
 from official.modeling.multitask import base_model
-from official.modeling.multitask import multitask


 @gin.configurable
@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
   def __init__(
       self,
-      task: multitask.MultiTask,
+      eval_tasks: List[base_task.Task],
       model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
       global_step: Optional[tf.Variable] = None,
+      eval_steps: Optional[Dict[str, int]] = None,
       checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
     """Initialize common trainer for TensorFlow models.

     Args:
-      task: A multitask.MultiTask instance.
+      eval_tasks: A list of tasks to evaluate.
       model: tf.keras.Model instance.
       global_step: the global step variable.
+      eval_steps: a dictionary of steps to run eval keyed by task names.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
         interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
-    self._task = task
+    self._tasks = eval_tasks
     self._model = model
     self._global_step = global_step or orbit.utils.create_global_step()
     self._checkpoint_exporter = checkpoint_exporter
     self._checkpoint = tf.train.Checkpoint(
         global_step=self.global_step,
         model=self.model)

     self._validation_losses = None
     self._validation_metrics = None

     # Builds per-task datasets.
     self.eval_datasets = {}
-    for name, task in self.task.tasks.items():
-      self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
+    self.eval_steps = eval_steps or {}
+    for task in self.tasks:
+      self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
           self.strategy, task.build_inputs, task.task_config.validation_data)

     # Builds per-task validation loops.
@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
       return orbit.utils.create_loop_fn(eval_step_fn)

     self.task_fns = {
-        name: get_function(name, task)
-        for name, task in self.task.tasks.items()
+        task.name: get_function(task.name, task)
+        for task in self.tasks
     }

   @property
@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     return self._strategy

   @property
-  def task(self):
-    return self._task
+  def tasks(self):
+    return self._tasks

   @property
   def model(self):
@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     if self._validation_losses is None:
       # Builds the per-task metrics and losses.
       self._validation_losses = {}
-      for name in self.task.tasks:
-        self._validation_losses[name] = tf.keras.metrics.Mean(
+      for task in self.tasks:
+        self._validation_losses[task.name] = tf.keras.metrics.Mean(
             "validation_loss", dtype=tf.float32)
     return self._validation_losses
@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     if self._validation_metrics is None:
       # Builds the per-task metrics and losses.
       self._validation_metrics = {}
-      for name, task in self.task.tasks.items():
-        self._validation_metrics[name] = task.build_metrics(training=False)
+      for task in self.tasks:
+        self._validation_metrics[task.name] = task.build_metrics(training=False)
     return self._validation_metrics

   @property
@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     results = {}
     eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

-    for name, task_eval_loop in self.task_fns.items():
+    for task in self.tasks:
       outputs = None
+      name = task.name
       eval_iter = eval_iters[name]
-      task = self.task.tasks[name]
-      task_eval_steps = self.task.task_eval_steps(name) or num_steps
-      outputs = task_eval_loop(
+      task_eval_steps = self.eval_steps.get(name, None) or num_steps
+      outputs = self.task_fns[name](
           eval_iter,
           task_eval_steps,
           state=outputs,
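The evaluator no longer wraps tasks in a `multitask.MultiTask`: callers pass the task list (and an optional per-task step budget) directly, as the updated tests in evaluator_test.py below do. A sketch of the new call shape, reusing the MockTask/MockModel helpers from that test file (the eval_steps values are illustrative):

tasks = [
    MockTask(params=cfg.TaskConfig(), name="bar"),
    MockTask(params=cfg.TaskConfig(), name="foo"),
]
test_evaluator = evaluator.MultiTaskEvaluator(
    eval_tasks=tasks,
    model=MockModel(),
    eval_steps={"bar": 10, "foo": 5},  # optional; falls back to num_steps
)
results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))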
official/modeling/multitask/evaluator_test.py

@@ -22,7 +22,6 @@ from tensorflow.python.distribute import strategy_combinations
 from official.core import base_task
 from official.core import config_definitions as cfg
 from official.modeling.multitask import evaluator
-from official.modeling.multitask import multitask


 def all_strategy_combinations():
@@ -89,9 +88,7 @@ class MockTask(base_task.Task):
             np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
     return state

-  def reduce_aggregated_logs(self,
-                             aggregated_logs,
-                             global_step=None):
+  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
     for k, v in aggregated_logs.items():
       aggregated_logs[k] = np.sum(np.stack(v, axis=0))
     return aggregated_logs
@@ -106,10 +103,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
     self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
     self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
@@ -123,10 +119,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertEqual(results["bar"]["counter"],
                      5. * distribution.num_replicas_in_sync)
official/modeling/multitask/interleaving_trainer.py

@@ -34,7 +34,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
       optimizer: tf.optimizers.Optimizer,
       task_sampler: sampler.TaskSampler,
       trainer_options=None):
-    super(MultiTaskInterleavingTrainer, self).__init__(
+    super().__init__(
         multi_task=multi_task,
         multi_task_model=multi_task_model,
         optimizer=optimizer,
@@ -90,3 +90,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
           self._task_train_step_map[name], args=(next(iterator_map[name]),))
       self.global_step.assign_add(1)
       self.task_step_counter(name).assign_add(1)
+
+  def train_loop_end(self):
+    """Record loss and metric values per task."""
+    result = super().train_loop_end()
+    # Interleaving training does not have a good semantic for `total_loss`. In
+    # fact, it is always zero. To avoid confusion, we filter the `total_loss`
+    # from the result logs.
+    if 'total_loss' in result:
+      result.pop('total_loss')
+    return result
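The new `train_loop_end` hook drops the meaningless zero `total_loss` before the logs are reported; the assertion added in the test below pins that behavior. A toy illustration of the filtering (the log values are made up):

result = {'total_loss': 0.0,
          'foo': {'training_loss': 0.4},
          'bar': {'training_loss': 0.9}}
if 'total_loss' in result:
  result.pop('total_loss')
assert 'total_loss' not in result  # per-task losses are kept intact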
official/modeling/multitask/interleaving_trainer_test.py

@@ -60,6 +60,7 @@ class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
                               results["bar"].keys())
     self.assertContainsSubset(["training_loss", "foo_acc"],
                               results["foo"].keys())
+    self.assertNotIn("total_loss", results)

   @combinations.generate(all_strategy_combinations())
   def test_trainer_with_configs(self, distribution):
official/modeling/multitask/multitask.py

@@ -59,10 +59,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     else:
       raise ValueError("The tasks argument has an invalid type: %s" %
                        type(tasks))
-    self._task_eval_steps = task_eval_steps or {}
-    self._task_eval_steps = dict([
-        (name, self._task_eval_steps.get(name, None)) for name in self.tasks
-    ])
+    self.task_eval_steps = task_eval_steps or {}
     self._task_weights = task_weights or {}
     self._task_weights = dict([
         (name, self._task_weights.get(name, 1.0)) for name in self.tasks
@@ -74,9 +71,9 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     task_eval_steps = {}
     task_weights = {}
     for task_routine in config.task_routines:
-      task_name = task_routine.task_name
+      task_name = task_routine.task_name or task_routine.task_config.name
       tasks[task_name] = task_factory.get_task(
-          task_routine.task_config, logging_dir=logging_dir)
+          task_routine.task_config, logging_dir=logging_dir, name=task_name)
       task_eval_steps[task_name] = task_routine.eval_steps
       task_weights[task_name] = task_routine.task_weight
     return cls(
@@ -86,9 +83,6 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
   def tasks(self):
     return self._tasks

-  def task_eval_steps(self, task_name):
-    return self._task_eval_steps[task_name]
-
   def task_weight(self, task_name):
     return self._task_weights[task_name]
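With the `TaskConfig.name` field added in config_definitions.py, `from_config` can now fall back to the task config's own name when a routine omits `task_name` (empty string is falsy). A sketch of the lookup order, assuming a `TaskConfig`-derived `FooConfig` like the one in the tests:

# Either routine resolves to the same task name, 'foo':
routine_a = configs.TaskRoutine(task_name='foo',
                                task_config=FooConfig())
routine_b = configs.TaskRoutine(task_config=FooConfig(name='foo'))

for task_routine in (routine_a, routine_b):
  # Empty-string task_name is falsy, so the config's own name wins.
  task_name = task_routine.task_name or task_routine.task_config.name
  assert task_name == 'foo'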
official/modeling/multitask/train_lib.py

@@ -15,7 +15,7 @@
 """Multitask training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Optional
+from typing import List, Optional
 from absl import logging
 import orbit
 import tensorflow as tf
@@ -69,9 +69,11 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
   trainer = TRAINERS[params.trainer.trainer_type](
       **kwargs) if is_training else None
   if is_eval:
+    eval_steps = task.task_eval_steps
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=task,
+        eval_tasks=task.tasks.values(),
         model=model,
+        eval_steps=eval_steps,
         global_step=trainer.global_step if is_training else None,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
@@ -137,7 +139,7 @@ def run_experiment_with_multitask_eval(
     *,
     distribution_strategy: tf.distribute.Strategy,
     train_task: base_task.Task,
-    eval_tasks: multitask.MultiTask,
+    eval_tasks: List[base_task.Task],
     mode: str,
     params: configs.MultiEvalExperimentConfig,
     model_dir: str,
@@ -149,7 +151,7 @@ def run_experiment_with_multitask_eval(
   Args:
     distribution_strategy: A distribution distribution_strategy.
     train_task: A base_task.Task instance.
-    eval_tasks: A multitask.MultiTask with evaluation tasks.
+    eval_tasks: A list of evaluation tasks.
     mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
       or 'continuous_eval'.
     params: MultiEvalExperimentConfig instance.
@@ -173,8 +175,8 @@ def run_experiment_with_multitask_eval(
         config=params,
         task=train_task,
         model=train_task.build_model(),
         optimizer=train_task.create_optimizer(
             params.trainer.optimizer_config, params.runtime),
         train=True,
         evaluate=False)
   else:
@@ -182,10 +184,14 @@ def run_experiment_with_multitask_eval(
   model = trainer.model if trainer else train_task.build_model()

   if is_eval:
+    eval_steps = dict([(task_routine.task_config.name, task_routine.eval_steps)
+                       for task_routine in params.eval_tasks])
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=eval_tasks,
+        eval_tasks=eval_tasks,
         model=model,
         global_step=trainer.global_step if is_training else None,
+        eval_steps=eval_steps,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
   else:
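The second call site derives the per-task step budget from the experiment config itself rather than from a `MultiTask` object. A standalone sketch of that dict construction (the routine values are illustrative):

routines = (
    configs.TaskRoutine(task_name='foo',
                        task_config=FooConfig(name='foo'), eval_steps=2),
    configs.TaskRoutine(task_name='bar',
                        task_config=BarConfig(name='bar'), eval_steps=3),
)
eval_steps = dict([(task_routine.task_config.name, task_routine.eval_steps)
                   for task_routine in routines])
assert eval_steps == {'foo': 2, 'bar': 3}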
official/modeling/multitask/train_lib_test.py

@@ -65,8 +65,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
         task=configs.MultiTaskConfig(
             task_routines=(
                 configs.TaskRoutine(
                     task_name='foo', task_config=test_utils.FooConfig()),
                 configs.TaskRoutine(
                     task_name='bar', task_config=test_utils.BarConfig()))))
     experiment_config = params_dict.override_params_dict(
@@ -95,18 +94,20 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
     model_dir = self.get_temp_dir()
     experiment_config = configs.MultiEvalExperimentConfig(
         task=test_utils.FooConfig(),
-        eval_tasks=configs.MultiTaskConfig(
-            task_routines=(
-                configs.TaskRoutine(
-                    task_name='foo',
-                    task_config=test_utils.FooConfig()),
-                configs.TaskRoutine(
-                    task_name='bar',
-                    task_config=test_utils.BarConfig()))))
+        eval_tasks=(configs.TaskRoutine(
+            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
+                    configs.TaskRoutine(
+                        task_name='bar',
+                        task_config=test_utils.BarConfig(),
+                        eval_steps=3)))
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
     with distribution_strategy.scope():
       train_task = task_factory.get_task(experiment_config.task)
-      eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
+      eval_tasks = [
+          task_factory.get_task(config.task_config, name=config.task_name)
+          for config in experiment_config.eval_tasks
+      ]
       train_lib.run_experiment_with_multitask_eval(
           distribution_strategy=distribution_strategy,
           train_task=train_task,
official/nlp/continuous_finetune_lib.py

@@ -28,7 +28,6 @@ from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
 from official.modeling.multitask import configs
-from official.modeling.multitask import multitask
 from official.modeling.multitask import train_lib as multitask_train_lib
@@ -167,7 +166,10 @@ def run_continuous_finetune(
     with distribution_strategy.scope():
       if isinstance(params, configs.MultiEvalExperimentConfig):
         task = task_factory.get_task(params_replaced.task)
-        eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
+        eval_tasks = [
+            task_factory.get_task(config.task_config, name=config.task_name)
+            for config in params.eval_tasks
+        ]
         (_,
          eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
              distribution_strategy=distribution_strategy,
official/nlp/metrics/bleu.py

@@ -89,8 +89,7 @@ def _get_ngrams_with_counter(segment, max_order):
   Args:
     segment: text segment from which n-grams will be extracted.
-    max_order: maximum length in tokens of the n-grams returned by this
-      methods.
+    max_order: maximum length in tokens of the n-grams returned by this methods.

   Returns:
     The Counter containing all n-grams upto max_order in segment
@@ -104,15 +103,17 @@ def _get_ngrams_with_counter(segment, max_order):
   return ngram_counts


-def compute_bleu(reference_corpus, translation_corpus, max_order=4,
+def compute_bleu(reference_corpus,
+                 translation_corpus,
+                 max_order=4,
                  use_bp=True):
   """Computes BLEU score of translated segments against one or more references.

   Args:
     reference_corpus: list of references for each translation. Each reference
       should be tokenized into a list of tokens.
     translation_corpus: list of translations to score. Each translation should
       be tokenized into a list of tokens.
     max_order: Maximum n-gram order to use when computing BLEU score.
     use_bp: boolean, whether to apply brevity penalty.
@@ -134,15 +135,14 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
     ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
     translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)

     overlap = dict((ngram, min(count, translation_ngram_counts[ngram]))
                    for ngram, count in ref_ngram_counts.items())

     for ngram in overlap:
       matches_by_order[len(ngram) - 1] += overlap[ngram]
     for ngram in translation_ngram_counts:
-      possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
-          ngram]
+      possible_matches_by_order[len(ngram) -
+                                1] += translation_ngram_counts[ngram]

     precisions = [0] * max_order
     smooth = 1.0
@@ -151,8 +151,8 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
     if possible_matches_by_order[i] > 0:
       precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
       if matches_by_order[i] > 0:
-        precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
-            i]
+        precisions[i] = float(
+            matches_by_order[i]) / possible_matches_by_order[i]
       else:
         smooth *= 2
         precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
@@ -165,7 +165,8 @@ def compute_bleu(reference_corpus, translation_corpus, max_order=4,
   if use_bp:
     ratio = translation_length / reference_length
-    bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
+    bp = 0. if ratio < 1e-6 else math.exp(
+        1 - 1. / ratio) if ratio < 1.0 else 1.0
   bleu = geo_mean * bp
   return np.float32(bleu)
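The only behavioral change in this file is the brevity-penalty guard: when the translation is empty, the length ratio is zero and `1. / ratio` would divide by zero, so the penalty now floors at 0 below a 1e-6 ratio. A self-contained sketch of the guarded formula (the `brevity_penalty` helper is illustrative; the real code inlines it):

import math

def brevity_penalty(translation_length: int, reference_length: int) -> float:
  """BLEU brevity penalty with the guard added in this commit."""
  ratio = translation_length / reference_length
  if ratio < 1e-6:   # empty translation: penalty floors at 0
    return 0.
  if ratio < 1.0:    # shorter than the reference: exponential penalty
    return math.exp(1 - 1. / ratio)
  return 1.0         # no penalty at or above reference length

assert brevity_penalty(0, 10) == 0.        # previously a ZeroDivisionError
assert brevity_penalty(10, 10) == 1.0
assert 0. < brevity_penalty(5, 10) < 1.0   # exp(1 - 2), roughly 0.37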
official/nlp/modeling/__init__.py

@@ -22,3 +22,4 @@ from official.nlp.modeling import layers
 from official.nlp.modeling import losses
 from official.nlp.modeling import models
 from official.nlp.modeling import networks
+from official.nlp.modeling import ops
official/nlp/modeling/layers/mobile_bert_layers.py

@@ -39,6 +39,23 @@ class NoNorm(tf.keras.layers.Layer):
     return output


+@tf.keras.utils.register_keras_serializable(package='Text')
+class NoNormClipped(NoNorm):
+  """Quantization friendly implementation for the NoNorm.
+
+  The output of NoNorm layer is clipped to [-6.0, 6.0] to make it quantization
+  friendly.
+  """
+
+  def __init__(self, name=None):
+    super(NoNormClipped, self).__init__(name=name)
+
+  def call(self, feature):
+    output = feature * self.scale + self.bias
+    clipped_output = tf.clip_by_value(output, -6.0, 6.0)
+    return clipped_output
+
+
 def _get_norm_layer(normalization_type='no_norm', name=None):
   """Get normlization layer.
@@ -52,6 +69,8 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
   """
   if normalization_type == 'no_norm':
     layer = NoNorm(name=name)
+  elif normalization_type == 'no_norm_clipped':
+    layer = NoNormClipped(name=name)
   elif normalization_type == 'layer_norm':
     layer = tf.keras.layers.LayerNormalization(
         name=name,
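Clipping the affine NoNorm output to a fixed [-6, 6] range bounds the activation distribution, which gives the quantizer known min/max statistics and keeps post-training quantization scales small. A minimal self-contained sketch of the same idea (the `NoNormClippedSketch` class is illustrative; the real layer inherits `scale` and `bias` from `NoNorm`):

import tensorflow as tf

class NoNormClippedSketch(tf.keras.layers.Layer):
  """Element-wise scale-and-shift with a quantization-friendly output range."""

  def build(self, input_shape):
    # NoNorm learns a per-channel affine transform instead of normalizing.
    self.scale = self.add_weight('scale', shape=[input_shape[-1]],
                                 initializer='ones')
    self.bias = self.add_weight('bias', shape=[input_shape[-1]],
                                initializer='zeros')

  def call(self, feature):
    output = feature * self.scale + self.bias
    # A fixed clip range bounds the activations the quantizer must represent.
    return tf.clip_by_value(output, -6.0, 6.0)

layer = NoNormClippedSketch()
out = layer(tf.random.uniform([2, 3, 4], minval=-8, maxval=8))
assert float(tf.reduce_max(out)) <= 6.0 and float(tf.reduce_min(out)) >= -6.0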
official/nlp/modeling/layers/mobile_bert_layers_test.py

@@ -33,6 +33,22 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
   return fake_input


+class EdgeTPUNoNormTest(tf.test.TestCase):
+
+  def test_no_norm(self):
+    layer = mobile_bert_layers.NoNormClipped()
+    feature = tf.random.uniform(
+        [2, 3, 4], minval=-8, maxval=8, dtype=tf.float32)
+    output = layer(feature)
+    output_shape = output.shape.as_list()
+    expected_shape = [2, 3, 4]
+    self.assertListEqual(output_shape, expected_shape, msg=None)
+    output_min = tf.reduce_min(output)
+    output_max = tf.reduce_max(output)
+    self.assertGreaterEqual(6.0, output_max)
+    self.assertLessEqual(-6.0, output_min)
+
+
 class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):

   def test_embedding_layer_with_token_type(self):
official/nlp/modeling/layers/spectral_normalization.py

@@ -106,16 +106,19 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
   def call(self, inputs, *, training=None):
     training = self.do_power_iteration if training is None else training
-    u_update_op, v_update_op, w_update_op = self.update_weights(
-        training=training)
-    output = self.layer(inputs)
-    w_restore_op = self.restore_weights()
-
-    # Register update ops.
-    self.add_update(u_update_op)
-    self.add_update(v_update_op)
-    self.add_update(w_update_op)
-    self.add_update(w_restore_op)
+    if training:
+      u_update_op, v_update_op, w_update_op = self.update_weights(
+          training=training)
+      output = self.layer(inputs)
+      w_restore_op = self.restore_weights()
+
+      # Register update ops.
+      self.add_update(u_update_op)
+      self.add_update(v_update_op)
+      self.add_update(w_update_op)
+      self.add_update(w_restore_op)
+    else:
+      output = self.layer(inputs)

     return output
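Previously the power-iteration and weight-mutation ops ran on every call; now they are gated on `training`, so inference invokes only the wrapped layer. An annotated restatement of the new control flow (the update/restore helpers are the wrapper's own `update_weights`/`restore_weights`, shown here for orientation rather than as new code):

def call(self, inputs, *, training=None):
  training = self.do_power_iteration if training is None else training
  if training:
    # Power iteration re-estimates the spectral norm and rescales the kernel
    # before the forward pass, then schedules a restore of the raw kernel.
    u_update_op, v_update_op, w_update_op = self.update_weights(
        training=training)
    output = self.layer(inputs)
    w_restore_op = self.restore_weights()
    for op in (u_update_op, v_update_op, w_update_op, w_restore_op):
      self.add_update(op)
  else:
    # Inference: the kernel is left untouched and no update ops are created.
    output = self.layer(inputs)
  return output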
official/nlp/modeling/models/bert_classifier.py

@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
     dropout_rate: The dropout probability of the cls head.
     use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
       encoder.
+    head_name: Name of the classification head.
     cls_head: (Optional) The layer instance to use for the classifier head.
       It should take in the output from network and produce the final logits.
       If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
-      'use_encoder_pooler') will be ignored.
+      'use_encoder_pooler', 'head_name') will be ignored.
   """

   def __init__(self,
@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
                initializer='glorot_uniform',
                dropout_rate=0.1,
                use_encoder_pooler=True,
+               head_name='sentence_prediction',
                cls_head=None,
                **kwargs):
     self.num_classes = num_classes
+    self.head_name = head_name
     self.initializer = initializer
     self.use_encoder_pooler = use_encoder_pooler
@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
           num_classes=num_classes,
           initializer=initializer,
           dropout_rate=dropout_rate,
-          name='sentence_prediction')
+          name=head_name)
       predictions = classifier(cls_inputs)
@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
     return {
         'network': self._network,
         'num_classes': self.num_classes,
+        'head_name': self.head_name,
        'initializer': self.initializer,
         'use_encoder_pooler': self.use_encoder_pooler,
         'cls_head': self._cls_head,
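The new `head_name` argument lets callers rename the classification head, which was previously hard-coded to 'sentence_prediction', without supplying a full custom `cls_head`; it is also round-tripped through `get_config`. A usage sketch, assuming an already-built encoder network (the encoder variable and head name are illustrative):

classifier = bert_classifier.BertClassifier(
    network=encoder,       # an existing BertEncoder instance
    num_classes=3,
    head_name='nli_head',  # replaces the old fixed 'sentence_prediction'
)
# head_name survives serialization round-trips:
config = classifier.get_config()
assert config['head_name'] == 'nli_head'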
official/nlp/modeling/models/bert_classifier_test.py

@@ -87,10 +87,8 @@ class BertClassifierTest(keras_parameterized.TestCase):
               inner_dim=0, num_classes=4)))
   def test_serialize_deserialize(self, cls_head):
     """Validate that the BERT trainer can be serialized and deserialized."""
-    # Build a transformer network to use within the BERT trainer. (Here, we use
-    # a short sequence_length for convenience.)
-    test_network = networks.BertEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+    # Build a transformer network to use within the BERT trainer.
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
official/nlp/modeling/models/bert_pretrainer_test.py

@@ -67,10 +67,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
   def test_bert_trainer_tensor_call(self):
     """Validate that the Keras object can be invoked."""
-    # Build a transformer network to use within the BERT trainer. (Here, we use
-    # a short sequence_length for convenience.)
-    test_network = networks.BertEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+    # Build a transformer network to use within the BERT trainer.
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_pretrainer.BertPretrainer(
@@ -213,10 +211,8 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
   def test_v2_serialize_deserialize(self):
     """Validate that the BERT trainer can be serialized and deserialized."""
-    # Build a transformer network to use within the BERT trainer. (Here, we use
-    # a short sequence_length for convenience.)
-    test_network = networks.BertEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+    # Build a transformer network to use within the BERT trainer.
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
official/nlp/modeling/models/bert_span_labeler_test.py

@@ -93,10 +93,8 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
   def test_serialize_deserialize(self):
     """Validate that the BERT trainer can be serialized and deserialized."""
-    # Build a transformer network to use within the BERT trainer. (Here, we use
-    # a short sequence_length for convenience.)
-    test_network = networks.BertEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+    # Build a transformer network to use within the BERT trainer.
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)