ModelZoo / ResNet50_tensorflow / Commits

Commit a04d9e0e, authored Jun 14, 2021 by Vishnu Banna.
Merge commit; parents: 64f16d61, bcbce005.
Changes: 120 files in total; this page shows 20 changed files with 241 additions and 76 deletions (+241 −76).
CODEOWNERS (+1 −0)
official/__init__.py (+14 −0)
official/core/base_task.py (+33 −3)
official/core/base_trainer.py (+3 −2)
official/modeling/activations/sigmoid.py (+1 −1)
official/modeling/activations/swish.py (+2 −1)
official/modeling/optimization/configs/optimization_config.py (+2 −0)
official/modeling/optimization/configs/optimizer_config.py (+21 −0)
official/modeling/optimization/ema_optimizer.py (+2 −2)
official/modeling/optimization/optimizer_factory.py (+5 −3)
official/modeling/optimization/slide_optimizer.py (+20 −0)
official/modeling/progressive/trainer.py (+4 −1)
official/nlp/data/classifier_data_lib.py (+16 −18)
official/nlp/data/sentence_prediction_dataloader.py (+19 −6)
official/nlp/data/sentence_prediction_dataloader_test.py (+49 −17)
official/nlp/modeling/models/xlnet.py (+2 −2)
official/nlp/modeling/ops/sampling_module.py (+8 −8)
official/nlp/projects/mobilebert/README.md (+1 −1)
official/nlp/projects/teams/README.md (+21 −0)
official/nlp/tasks/sentence_prediction.py (+17 −11)
CODEOWNERS

 * @tensorflow/tf-garden-team @tensorflow/tf-model-garden-team
 /official/ @rachellj218 @saberkun @jaeyounkim
 /official/nlp/ @saberkun @lehougoogle @rachellj218 @jaeyounkim
+/official/recommendation/ranking/ @gagika
 /official/vision/ @xianzhidu @yeqingli @arashwan @saberkun @rachellj218 @jaeyounkim
 /official/vision/beta/projects/assemblenet/ @mryoo
 /official/vision/beta/projects/deepmac_maskrcnn/ @vighneshbirodkar
...
official/__init__.py (new content: the Apache 2.0 license header)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/core/base_task.py

@@ -38,7 +38,10 @@ class Task(tf.Module, metaclass=abc.ABCMeta):

   # Special keys in train/validate step returned logs.
   loss = "loss"

-  def __init__(self, params, logging_dir: str = None, name: str = None):
+  def __init__(self,
+               params,
+               logging_dir: Optional[str] = None,
+               name: Optional[str] = None):
     """Task initialization.

     Args:
...
@@ -294,11 +297,38 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     return model(inputs, training=False)

   def aggregate_logs(self, state, step_logs):
-    """Optional aggregation over logs returned from a validation step."""
+    """Optional aggregation over logs returned from a validation step.
+
+    Given step_logs from a validation step, this function aggregates the logs
+    after each eval_step() (see eval_reduce() function in
+    official/core/base_trainer.py). It runs on CPU and can be used to aggregate
+    metrics during validation, when there are too many metrics that cannot fit
+    into TPU memory. Note that this may increase latency due to data transfer
+    between TPU and CPU. Also, the step output from a validation step may be a
+    tuple with elements from replicas, and a concatenation of the elements is
+    needed in such case.
+
+    Args:
+      state: The current state of training, for example, it can be a sequence
+        of metrics.
+      step_logs: Logs from a validation step. Can be a dictionary.
+    """
     pass

   def reduce_aggregated_logs(self,
                              aggregated_logs,
                              global_step: Optional[tf.Tensor] = None):
-    """Optional reduce of aggregated logs over validation steps."""
+    """Optional reduce of aggregated logs over validation steps.
+
+    This function reduces aggregated logs at the end of validation, and can be
+    used to compute the final metrics. It runs on CPU and in each eval_end() in
+    base trainer (see eval_end() function in official/core/base_trainer.py).
+
+    Args:
+      aggregated_logs: Aggregated logs over multiple validation steps.
+      global_step: An optional variable of global step.
+
+    Returns:
+      A dictionary of reduced results.
+    """
     return {}
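The new docstrings define a CPU-side aggregation path for validation. A hedged sketch of how a subclass might use the two hooks, ignoring the other abstract methods a concrete Task must implement; the 'predictions' key and the mean metric are illustrative assumptions, not part of the base API:

import numpy as np

class MyEvalTask(Task):  # Task as defined in official/core/base_task.py

  def aggregate_logs(self, state, step_logs):
    # Runs on CPU after each eval_step(); accumulate per-step outputs.
    if state is None:
      state = {'predictions': []}
    state['predictions'].append(np.asarray(step_logs['predictions']))
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    # Runs once in eval_end(); compute the final metric on the host.
    preds = np.concatenate(aggregated_logs['predictions'], axis=0)
    return {'mean_prediction': float(preds.mean())}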
official/core/base_trainer.py

@@ -246,10 +246,11 @@ class Trainer(_AsyncTrainer):
     self._train_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
     self._validation_loss = tf.keras.metrics.Mean(
         "validation_loss", dtype=tf.float32)
+    model_metrics = model.metrics if hasattr(model, "metrics") else []
     self._train_metrics = self.task.build_metrics(
-        training=True) + self.model.metrics
+        training=True) + model_metrics
     self._validation_metrics = self.task.build_metrics(
-        training=False) + self.model.metrics
+        training=False) + model_metrics
     self.init_async()
...
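Hoisting model.metrics into a local avoids evaluating the property twice and, via hasattr, tolerates models that do not expose it. A minimal sketch of the guard in isolation (DummyModule is a hypothetical stand-in for a metrics-less model):

import tensorflow as tf

class DummyModule(tf.Module):
  # A bare tf.Module does not define a `metrics` property, unlike tf.keras.Model.
  pass

for model in (tf.keras.Sequential([tf.keras.layers.Dense(1)]), DummyModule()):
  model_metrics = model.metrics if hasattr(model, "metrics") else []
  print(type(model).__name__, model_metrics)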
official/modeling/activations/sigmoid.py

@@ -28,4 +28,4 @@ def hard_sigmoid(features):
     The activation value.
   """
   features = tf.convert_to_tensor(features)
-  return tf.nn.relu6(features + tf.constant(3.)) * 0.16667
+  return tf.nn.relu6(features + tf.cast(3., features.dtype)) * 0.16667
official/modeling/activations/swish.py

@@ -52,7 +52,8 @@ def hard_swish(features):
     The activation value.
   """
   features = tf.convert_to_tensor(features)
-  return features * tf.nn.relu6(features + tf.constant(3.)) * (1. / 6.)
+  fdtype = features.dtype
+  return features * tf.nn.relu6(features + tf.cast(3., fdtype)) * (1. / 6.)


 @tf.keras.utils.register_keras_serializable(package='Text')
...
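Both activation fixes address the same failure mode: tf.constant(3.) is always float32, so features + tf.constant(3.) raises a dtype error for float16/bfloat16 inputs under mixed precision. A minimal sketch reproducing the issue and the fix:

import tensorflow as tf

def hard_sigmoid_fixed(features):
  # Casting the constant to the input dtype keeps the op dtype-agnostic.
  features = tf.convert_to_tensor(features)
  return tf.nn.relu6(features + tf.cast(3., features.dtype)) * 0.16667

x16 = tf.constant([-1.0, 0.0, 1.0], dtype=tf.float16)
# Old version: tf.nn.relu6(x16 + tf.constant(3.)) raises an error because
# float16 + float32 is not an allowed AddV2 combination.
print(hard_sigmoid_fixed(x16))  # float16 output, no dtype mismatch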
official/modeling/optimization/configs/optimization_config.py

@@ -41,6 +41,7 @@ class OptimizerConfig(oneof.OneOfConfig):
     rmsprop: rmsprop optimizer.
     lars: lars optimizer.
     adagrad: adagrad optimizer.
+    slide: slide optimizer.
   """
   type: Optional[str] = None
   sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
...
@@ -50,6 +51,7 @@ class OptimizerConfig(oneof.OneOfConfig):
   rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
   lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
   adagrad: opt_cfg.AdagradConfig = opt_cfg.AdagradConfig()
+  slide: opt_cfg.SLIDEConfig = opt_cfg.SLIDEConfig()


 @dataclasses.dataclass
...
official/modeling/optimization/configs/optimizer_config.py

@@ -226,3 +226,24 @@ class LARSConfig(BaseOptimizerConfig):
   classic_momentum: bool = True
   exclude_from_weight_decay: Optional[List[str]] = None
   exclude_from_layer_adaptation: Optional[List[str]] = None
+
+
+@dataclasses.dataclass
+class SLIDEConfig(BaseOptimizerConfig):
+  """Configuration for SLIDE optimizer.
+
+  Details coming soon.
+  """
+  name: str = "SLIDE"
+  beta_1: float = 0.9
+  beta_2: float = 0.999
+  epsilon: float = 1e-6
+  weight_decay_rate: float = 0.0
+  weight_decay_type: str = "inner"
+  exclude_from_weight_decay: Optional[List[str]] = None
+  exclude_from_layer_adaptation: Optional[List[str]] = None
+  include_in_sparse_layer_adaptation: Optional[List[str]] = None
+  sparse_layer_learning_rate: float = 0.1
+  do_gradient_rescaling: bool = True
+  norm_type: str = "layer"
+  ratio_clip_norm: float = 1e5
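Since OptimizerConfig is a oneof config, selecting SLIDE is a matter of setting type and overriding fields on the matching branch. A hedged sketch, assuming the OptimizationConfig wrapper defined alongside OptimizerConfig in this module (the field overrides are illustrative; note that actually building the optimizer would fail at this commit, because slide_optimizer.SLIDE is still a string placeholder):

from official.modeling.optimization.configs import optimization_config

# Pick the 'slide' branch of the oneof and override a couple of its fields.
opt_config = optimization_config.OptimizationConfig({
    'optimizer': {
        'type': 'slide',
        'slide': {'beta_1': 0.9, 'weight_decay_rate': 0.01},
    },
    'learning_rate': {
        'type': 'constant',
        'constant': {'learning_rate': 1e-4},
    },
})
print(opt_config.optimizer.type)  # 'slide'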
official/modeling/optimization/ema_optimizer.py

@@ -14,7 +14,7 @@
 """Exponential moving average optimizer."""

-from typing import Text, List
+from typing import List, Optional, Text

 import tensorflow as tf
...
@@ -106,7 +106,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
   def _create_slots(self, var_list):
     self._optimizer._create_slots(var_list=var_list)  # pylint: disable=protected-access

-  def apply_gradients(self, grads_and_vars, name: Text = None):
+  def apply_gradients(self, grads_and_vars, name: Optional[Text] = None):
     result = self._optimizer.apply_gradients(grads_and_vars, name)
     self.update_average(self.iterations)
     return result
...
official/modeling/optimization/optimizer_factory.py

@@ -13,12 +13,13 @@
 # limitations under the License.

 """Optimizer factory class."""
-from typing import Callable, Union
+from typing import Callable, Optional, Union

 import gin
 import tensorflow as tf
 import tensorflow_addons.optimizers as tfa_optimizers

+from official.modeling.optimization import slide_optimizer
 from official.modeling.optimization import ema_optimizer
 from official.modeling.optimization import lars_optimizer
 from official.modeling.optimization import lr_schedule
...
@@ -33,6 +34,7 @@ OPTIMIZERS_CLS = {
     'rmsprop': tf.keras.optimizers.RMSprop,
     'lars': lars_optimizer.LARS,
     'adagrad': tf.keras.optimizers.Adagrad,
+    'slide': slide_optimizer.SLIDE
 }

 LR_CLS = {
...
@@ -134,8 +136,8 @@ class OptimizerFactory:
   def build_optimizer(
       self,
       lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
-      postprocessor: Callable[[tf.keras.optimizers.Optimizer],
-                              tf.keras.optimizers.Optimizer] = None):
+      postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
+                                       tf.keras.optimizers.Optimizer]] = None):
     """Build optimizer.

     Builds optimizer from config. It takes learning rate as input, and builds
...
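build_optimizer's postprocessor hook (now correctly annotated as Optional) lets callers wrap the freshly built optimizer before it is returned. A hedged sketch, assuming an opt_config like the one above (with a buildable optimizer type such as 'sgd') and using a LossScaleOptimizer wrapper purely as an illustration:

import tensorflow as tf
from official.modeling.optimization import optimizer_factory

factory = optimizer_factory.OptimizerFactory(opt_config)
lr = factory.build_learning_rate()

def wrap_for_mixed_precision(optimizer):
  # The postprocessor receives the built optimizer and must return an optimizer.
  return tf.keras.mixed_precision.LossScaleOptimizer(optimizer)

optimizer = factory.build_optimizer(lr, postprocessor=wrap_for_mixed_precision)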
official/modeling/optimization/slide_optimizer.py (new file, 0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SLIDE optimizer.
A new optimizer that will be open sourced soon.
"""
SLIDE
=
"Unimplemented"
official/modeling/progressive/trainer.py

@@ -284,8 +284,11 @@ class ProgressiveTrainer(trainer_lib.Trainer):
         checkpoint_interval=checkpoint_interval,
     )

+    # Make sure we export the last checkpoint.
+    last_checkpoint = (self.global_step.numpy() ==
+                       self._config.trainer.train_steps)
     checkpoint_path = self._export_ckpt_manager.save(
         checkpoint_number=self.global_step.numpy(),
-        check_interval=True)
+        check_interval=not last_checkpoint)
     if checkpoint_path:
       logging.info('Checkpoints exported: %s.', checkpoint_path)
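The fix leans on the check_interval flag of tf.train.CheckpointManager.save: when True, the save is skipped unless the configured interval has elapsed since the last one; passing False forces the save, which is what the trainer now does on the final training step. A small sketch of that behavior (directory and step values are illustrative):

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory='/tmp/ckpts', max_to_keep=3,
    step_counter=step, checkpoint_interval=100)

step.assign(100)
manager.save(checkpoint_number=step.numpy())         # saved
step.assign(150)
print(manager.save(checkpoint_number=step.numpy()))  # None: interval not elapsed
# check_interval=False forces the save regardless of the interval.
print(manager.save(checkpoint_number=step.numpy(), check_interval=False))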
official/nlp/data/classifier_data_lib.py

@@ -181,20 +181,21 @@ class AxProcessor(DataProcessor):
 class ColaProcessor(DataProcessor):
   """Processor for the CoLA data set (GLUE version)."""

+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(ColaProcessor, self).__init__(process_text_fn)
+    self.dataset = tfds.load("glue/cola", try_gcs=True)
+
   def get_train_examples(self, data_dir):
     """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+    return self._create_examples_tfds("train")

   def get_dev_examples(self, data_dir):
     """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+    return self._create_examples_tfds("validation")

   def get_test_examples(self, data_dir):
     """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+    return self._create_examples_tfds("test")

   def get_labels(self):
     """See base class."""
...
@@ -205,22 +206,19 @@ class ColaProcessor(DataProcessor):
     """See base class."""
     return "COLA"

-  def _create_examples(self, lines, set_type):
+  def _create_examples_tfds(self, set_type):
     """Creates examples for the training/dev/test sets."""
+    dataset = self.dataset[set_type].as_numpy_iterator()
     examples = []
-    for i, line in enumerate(lines):
-      # Only the test set has a header.
-      if set_type == "test" and i == 0:
-        continue
+    for i, example in enumerate(dataset):
       guid = "%s-%s" % (set_type, i)
-      if set_type == "test":
-        text_a = self.process_text_fn(line[1])
-        label = "0"
-      else:
-        text_a = self.process_text_fn(line[3])
-        label = self.process_text_fn(line[1])
+      text_a = self.process_text_fn(example["sentence"])
+      label = "0"
+      if set_type != "test":
+        label = str(example["label"])
       examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+          InputExample(
+              guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
     return examples
...
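The processor now reads CoLA from TensorFlow Datasets instead of local TSV files. tfds.load("glue/cola") returns a dict of splits ("train", "validation", "test"), each yielding dicts of numpy values; a quick sketch of what _create_examples_tfds iterates over (the example values shown are illustrative):

import tensorflow_datasets as tfds

ds = tfds.load("glue/cola", try_gcs=True)
for example in ds["validation"].as_numpy_iterator():
  # example is a dict of numpy values, roughly:
  # {'sentence': b'The book was written by John.', 'label': 1, 'idx': 7}
  sentence = example["sentence"]
  label = example["label"]
  break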
official/nlp/data/sentence_prediction_dataloader.py

@@ -14,7 +14,7 @@
 """Loads dataset for the sentence prediction (classification) task."""
 import functools
-from typing import List, Mapping, Optional
+from typing import List, Mapping, Optional, Tuple

 import dataclasses
 import tensorflow as tf
...
@@ -40,6 +40,10 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   label_type: str = 'int'
   # Whether to include the example id number.
   include_example_id: bool = False
+  label_field: str = 'label_ids'
+  # Maps the key in TfExample to feature name.
+  # E.g 'label_ids' to 'next_sentence_labels'
+  label_name: Optional[Tuple[str, str]] = None


 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
...
@@ -50,6 +54,11 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     self._params = params
     self._seq_length = params.seq_length
     self._include_example_id = params.include_example_id
+    self._label_field = params.label_field
+    if params.label_name:
+      self._label_name_mapping = dict([params.label_name])
+    else:
+      self._label_name_mapping = dict()

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
...
@@ -58,7 +67,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], label_type),
+        self._label_field: tf.io.FixedLenFeature([], label_type),
     }
     if self._include_example_id:
       name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
...
@@ -85,8 +94,12 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
     if self._include_example_id:
       x['example_id'] = record['example_id']
-    y = record['label_ids']
-    return (x, y)
+    x[self._label_field] = record[self._label_field]
+    if self._label_field in self._label_name_mapping:
+      x[self._label_name_mapping[self._label_field]] = record[self._label_field]
+    return x

   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
...
@@ -204,8 +217,8 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
     model_inputs = self._text_processor(segments)
     if self._include_example_id:
       model_inputs['example_id'] = record['example_id']
-    y = record[self._label_field]
-    return model_inputs, y
+    model_inputs[self._label_field] = record[self._label_field]
+    return model_inputs

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
...
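With this change the loader yields a single features dict per batch instead of an (x, y) tuple, and label_name optionally mirrors the label under a second key. A hedged sketch of the resulting element structure (input_path is assumed to point at an existing TFRecord of preprocessed examples, like the one the test below fabricates):

from official.nlp.data import sentence_prediction_dataloader as loader

data_config = loader.SentencePredictionDataConfig(
    input_path='/tmp/train.tf_record',  # assumed to exist; see the test below
    seq_length=128,
    global_batch_size=10,
    label_type='int',
    label_name=('label_ids', 'next_sentence_labels'))
dataset = loader.SentencePredictionDataLoader(data_config).load()
features = next(iter(dataset))
# The batch dict now carries the label under both the original field name and
# the mapped one: features['label_ids'] and features['next_sentence_labels'].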
official/nlp/data/sentence_prediction_dataloader_test.py

@@ -132,14 +132,40 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
         global_batch_size=batch_size,
         label_type=label_type)
     dataset = loader.SentencePredictionDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
-                          features.keys())
+    features = next(iter(dataset))
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
-    self.assertEqual(labels.dtype, expected_label_type)
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features['label_ids'].dtype, expected_label_type)
+
+  def test_load_dataset_with_label_mapping(self):
+    input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+    batch_size = 10
+    seq_length = 128
+    _create_fake_preprocessed_dataset(input_path, seq_length, 'int')
+    data_config = loader.SentencePredictionDataConfig(
+        input_path=input_path,
+        seq_length=seq_length,
+        global_batch_size=batch_size,
+        label_type='int',
+        label_name=('label_ids', 'next_sentence_labels'))
+    dataset = loader.SentencePredictionDataLoader(data_config).load()
+    features = next(iter(dataset))
+    self.assertCountEqual([
+        'input_word_ids', 'input_mask', 'input_type_ids',
+        'next_sentence_labels', 'label_ids'
+    ], features.keys())
+    self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
+    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
+    self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+    self.assertEqual(features['label_ids'].shape, (batch_size,))
+    self.assertEqual(features['label_ids'].dtype, tf.int32)
+    self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
+    self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)


 class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
...
@@ -170,13 +196,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         lower_case=lower_case,
         vocab_file=vocab_file_path)
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    label_field = data_config.label_field
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features[label_field].shape, (batch_size,))

   @parameterized.parameters(True, False)
   def test_python_sentencepiece_preprocessing(self, use_tfds):
...
@@ -203,13 +231,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         vocab_file=sp_model_file_path,
     )
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    label_field = data_config.label_field
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features[label_field].shape, (batch_size,))

   @parameterized.parameters(True, False)
   def test_saved_model_preprocessing(self, use_tfds):
...
@@ -236,13 +266,15 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
         label_type='int' if use_tfds else 'float',
     )
     dataset = loader.SentencePredictionTextDataLoader(data_config).load()
-    features, labels = next(iter(dataset))
-    self.assertCountEqual(['input_word_ids', 'input_type_ids', 'input_mask'],
-                          features.keys())
+    features = next(iter(dataset))
+    label_field = data_config.label_field
+    self.assertCountEqual(
+        ['input_word_ids', 'input_type_ids', 'input_mask', label_field],
+        features.keys())
     self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
     self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
-    self.assertEqual(labels.shape, (batch_size,))
+    self.assertEqual(features[label_field].shape, (batch_size,))

 if __name__ == '__main__':
...
official/nlp/modeling/models/xlnet.py

@@ -15,7 +15,7 @@
 """XLNet models."""
 # pylint: disable=g-classes-have-attributes
-from typing import Any, Mapping, Union
+from typing import Any, Mapping, Optional, Union

 import tensorflow as tf
...
@@ -99,7 +99,7 @@ class XLNetPretrainer(tf.keras.Model):
                network: Union[tf.keras.layers.Layer, tf.keras.Model],
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
-               name: str = None,
+               name: Optional[str] = None,
                **kwargs):
     super().__init__(name=name, **kwargs)
     self._config = {
...
View file @
a04d9e0e
...
@@ -431,17 +431,17 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -431,17 +431,17 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def
_continue_search
(
self
,
state
)
->
tf
.
Tensor
:
def
_continue_search
(
self
,
state
)
->
tf
.
Tensor
:
i
=
state
[
decoding_module
.
StateKeys
.
CUR_INDEX
]
i
=
state
[
decoding_module
.
StateKeys
.
CUR_INDEX
]
return
tf
.
less
(
i
,
self
.
max_decode_length
)
# Have we reached max decoding length?
not_at_end
=
tf
.
less
(
i
,
self
.
max_decode_length
)
# Have all sampled sequences reached an EOS?
all_has_eos
=
tf
.
reduce_all
(
state
[
decoding_module
.
StateKeys
.
FINISHED_FLAGS
],
axis
=
None
,
name
=
"search_finish_cond"
)
return
tf
.
logical_and
(
not_at_end
,
tf
.
logical_not
(
all_has_eos
))
def
_finished_flags
(
self
,
topk_ids
,
state
)
->
tf
.
Tensor
:
def
_finished_flags
(
self
,
topk_ids
,
state
)
->
tf
.
Tensor
:
new_finished_flags
=
tf
.
equal
(
topk_ids
,
self
.
eos_id
)
new_finished_flags
=
tf
.
equal
(
topk_ids
,
self
.
eos_id
)
new_finished_flags
=
tf
.
logical_or
(
new_finished_flags
=
tf
.
logical_or
(
new_finished_flags
,
state
[
decoding_module
.
StateKeys
.
FINISHED_FLAGS
])
new_finished_flags
,
state
[
decoding_module
.
StateKeys
.
FINISHED_FLAGS
])
return
new_finished_flags
return
new_finished_flags
official/nlp/projects/mobilebert/README.md
View file @
a04d9e0e
...
@@ -22,7 +22,7 @@ modeling library:
...
@@ -22,7 +22,7 @@ modeling library:
*
[
mobile_bert_encoder.py
](
https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/mobile_bert_encoder.py
)
*
[
mobile_bert_encoder.py
](
https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/mobile_bert_encoder.py
)
contains
`MobileBERTEncoder`
implementation.
contains
`MobileBERTEncoder`
implementation.
*
[
mobile_bert_layers.py
](
https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/mobile_bert_layers.py
)
*
[
mobile_bert_layers.py
](
https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/mobile_bert_layers.py
)
contains
`MobileBertEmbedding`
,
`MobileBert
MaskedLM
`
and
`MobileBertMaskedLM`
contains
`MobileBertEmbedding`
,
`MobileBert
Transformer
`
and
`MobileBertMaskedLM`
implementation.
implementation.
## Pre-trained Models
## Pre-trained Models
...
...
official/nlp/projects/teams/README.md (new file)

# TEAMS (Training ELECTRA Augmented with Multi-word Selection)

**Note:** This project is a work in progress; please stay tuned.

TEAMS is a text encoder pre-training method that simultaneously learns a
generator and a discriminator using multi-task learning. We propose a new
pre-training task, multi-word selection, and combine it with previous
pre-training tasks for efficient encoder pre-training. We also develop two
techniques, attention-based task-specific heads and partial layer sharing,
to further improve pre-training effectiveness.

Our academic paper [[1]](#1), which describes TEAMS in detail, can be found at
https://arxiv.org/abs/2106.00139.

## References

<a id="1">[1]</a>
Jiaming Shen, Jialu Liu, Tianqi Liu, Cong Yu and Jiawei Han, "Training ELECTRA
Augmented with Multi-word Selection", Findings of the Association for
Computational Linguistics: ACL 2021.
official/nlp/tasks/sentence_prediction.py

@@ -69,6 +69,10 @@ class SentencePredictionTask(base_task.Task):
     if params.metric_type not in METRIC_TYPES:
       raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
     self.metric_type = params.metric_type
+    if hasattr(params.train_data, 'label_field'):
+      self.label_field = params.train_data.label_field
+    else:
+      self.label_field = 'label_ids'

   def build_model(self):
     if self.task_config.hub_module_url and self.task_config.init_checkpoint:
...
@@ -95,11 +99,12 @@ class SentencePredictionTask(base_task.Task):
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+    label_ids = labels[self.label_field]
     if self.task_config.model.num_classes == 1:
-      loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
+      loss = tf.keras.losses.mean_squared_error(label_ids, model_outputs)
     else:
       loss = tf.keras.losses.sparse_categorical_crossentropy(
-          labels, tf.cast(model_outputs, tf.float32), from_logits=True)
+          label_ids, tf.cast(model_outputs, tf.float32), from_logits=True)

     if aux_losses:
       loss += tf.add_n(aux_losses)
...
@@ -120,7 +125,8 @@ class SentencePredictionTask(base_task.Task):
         y = tf.zeros((1,), dtype=tf.float32)
       else:
         y = tf.zeros((1, 1), dtype=tf.int32)
-      return x, y
+      x[self.label_field] = y
+      return x

     dataset = tf.data.Dataset.range(1)
     dataset = dataset.repeat()
...
@@ -142,16 +148,16 @@ class SentencePredictionTask(base_task.Task):
   def process_metrics(self, metrics, labels, model_outputs):
     for metric in metrics:
-      metric.update_state(labels, model_outputs)
+      metric.update_state(labels[self.label_field], model_outputs)

   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
-    compiled_metrics.update_state(labels, model_outputs)
+    compiled_metrics.update_state(labels[self.label_field], model_outputs)

   def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
     if self.metric_type == 'accuracy':
       return super(SentencePredictionTask,
                    self).validation_step(inputs, model, metrics)
-    features, labels = inputs
+    features, labels = inputs, inputs
     outputs = self.inference_step(features, model)
     loss = self.build_losses(
         labels=labels, model_outputs=outputs, aux_losses=model.losses)
...
@@ -161,12 +167,12 @@ class SentencePredictionTask(base_task.Task):
         'sentence_prediction':  # Ensure one prediction along batch dimension.
             tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
         'labels':
-            labels,
+            labels[self.label_field],
     })
     if self.metric_type == 'pearson_spearman_corr':
       logs.update({
           'sentence_prediction': outputs,
-          'labels': labels,
+          'labels': labels[self.label_field],
       })
     return logs
...
@@ -206,10 +212,10 @@ class SentencePredictionTask(base_task.Task):
   def initialize(self, model):
     """Load a pretrained checkpoint (if exists) and then train from iter 0."""
     ckpt_dir_or_file = self.task_config.init_checkpoint
-    if tf.io.gfile.isdir(ckpt_dir_or_file):
-      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
     if not ckpt_dir_or_file:
       return
+    if tf.io.gfile.isdir(ckpt_dir_or_file):
+      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
     pretrain2finetune_mapping = {
         'encoder': model.checkpoint_items['encoder'],
...
@@ -250,7 +256,7 @@ def predict(task: SentencePredictionTask,
   def predict_step(inputs):
     """Replicated prediction calculation."""
-    x, _ = inputs
+    x = inputs
     example_id = x.pop('example_id')
     outputs = task.inference_step(x, model)
     return dict(example_id=example_id, predictions=outputs)
...
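Taken together with the dataloader change, the task now consumes a single features dict that carries its own labels, and validation_step simply aliases that dict for both roles. A hedged sketch of the new data contract, with a hand-built stand-in batch:

import tensorflow as tf

# A stand-in for one batch from SentencePredictionDataLoader: one dict that
# carries its own labels (no separate labels tensor anymore).
batch = {
    'input_word_ids': tf.zeros((2, 8), tf.int32),
    'input_mask': tf.ones((2, 8), tf.int32),
    'input_type_ids': tf.zeros((2, 8), tf.int32),
    'label_ids': tf.constant([1, 0]),
}

# validation_step aliases the dict for both roles...
features, labels = batch, batch
# ...and every loss/metric site indexes the label tensor out explicitly.
label_ids = labels['label_ids']  # i.e. labels[self.label_field]
print(label_ids.numpy())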