Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
32e4ca51
Commit
32e4ca51
authored
Nov 28, 2023
by
qianyj
Browse files
Update code to v2.11.0
parents
9485aa1d
71060f67
Changes
772
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
287 additions
and
43 deletions
+287
-43
official/nlp/serving/export_savedmodel_test.py
official/nlp/serving/export_savedmodel_test.py
+1
-1
official/nlp/serving/export_savedmodel_util.py
official/nlp/serving/export_savedmodel_util.py
+1
-1
official/nlp/serving/serving_modules.py
official/nlp/serving/serving_modules.py
+53
-2
official/nlp/serving/serving_modules_test.py
official/nlp/serving/serving_modules_test.py
+75
-1
official/nlp/tasks/__init__.py
official/nlp/tasks/__init__.py
+1
-1
official/nlp/tasks/dual_encoder.py
official/nlp/tasks/dual_encoder.py
+6
-2
official/nlp/tasks/dual_encoder_test.py
official/nlp/tasks/dual_encoder_test.py
+2
-2
official/nlp/tasks/electra_task.py
official/nlp/tasks/electra_task.py
+1
-1
official/nlp/tasks/electra_task_test.py
official/nlp/tasks/electra_task_test.py
+1
-1
official/nlp/tasks/masked_lm.py
official/nlp/tasks/masked_lm.py
+1
-1
official/nlp/tasks/masked_lm_determinism_test.py
official/nlp/tasks/masked_lm_determinism_test.py
+103
-0
official/nlp/tasks/masked_lm_test.py
official/nlp/tasks/masked_lm_test.py
+1
-1
official/nlp/tasks/question_answering.py
official/nlp/tasks/question_answering.py
+5
-5
official/nlp/tasks/question_answering_test.py
official/nlp/tasks/question_answering_test.py
+1
-1
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+21
-12
official/nlp/tasks/sentence_prediction_test.py
official/nlp/tasks/sentence_prediction_test.py
+10
-7
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+1
-1
official/nlp/tasks/tagging_test.py
official/nlp/tasks/tagging_test.py
+1
-1
official/nlp/tasks/translation.py
official/nlp/tasks/translation.py
+1
-1
official/nlp/tasks/translation_test.py
official/nlp/tasks/translation_test.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
772 of 772+
files are displayed.
Plain diff
Email patch
official/nlp/serving/export_savedmodel_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/serving/export_savedmodel_util.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/serving/serving_modules.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -14,10 +14,12 @@
"""Serving export modules for TF Model Garden NLP models."""
# pylint:disable=missing-class-docstring
import
dataclasses
from
typing
import
Dict
,
List
,
Optional
,
Text
import
dataclasses
import
tensorflow
as
tf
import
tensorflow_text
as
tf_text
from
official.core
import
export_base
from
official.modeling.hyperparams
import
base_config
from
official.nlp.data
import
sentence_prediction_dataloader
...
...
@@ -407,3 +409,52 @@ class Tagging(export_base.ExportModule):
signatures
[
signature_key
]
=
self
.
serve_examples
.
get_concrete_function
(
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
,
name
=
"examples"
))
return
signatures
class Translation(export_base.ExportModule):
  """The export module for the translation task.

  Wraps a translation model together with a SentencePiece tokenizer so that
  the exported signature accepts raw strings and returns detokenized strings.
  """

  @dataclasses.dataclass
  class Params(base_config.Config):
    # Path to the serialized SentencePiece model file.
    sentencepiece_model_path: str = ""
    # Needs to be specified if padded_decode is True/on TPUs.
    batch_size: Optional[int] = None

  def __init__(self, params, model: tf.keras.Model, inference_step=None):
    """Builds the tokenizer and records the EOS id and batch size.

    Args:
      params: A `Translation.Params` instance; `sentencepiece_model_path` and
        `batch_size` are read from it.
      model: The Keras model to export.
      inference_step: Optional inference callable forwarded to the base class.

    Raises:
      ValueError: If tokenizing an empty string fails with
        `tf.errors.InternalError`, or (from `ndarray.item()`) if the result
        is not exactly one token.
    """
    super().__init__(params, model, inference_step)
    # Read the serialized model inside a context manager so the file handle
    # is released deterministically instead of waiting for garbage collection.
    with tf.io.gfile.GFile(params.sentencepiece_model_path, "rb") as sp_file:
      serialized_sp_model = sp_file.read()
    self._sp_tokenizer = tf_text.SentencepieceTokenizer(
        model=serialized_sp_model, add_eos=True)
    # With add_eos=True an empty string should tokenize to just the EOS id.
    try:
      empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
    except tf.errors.InternalError as err:
      raise ValueError(
          "EOS token not in tokenizer vocab. "
          "Please make sure the tokenizer generates a single token for an "
          "empty string.") from err
    self._eos_id = empty_str_tokenized.item()
    self._batch_size = params.batch_size

  @tf.function
  def serve(self, inputs) -> Dict[str, tf.Tensor]:
    """Runs the inference step on already-tokenized `inputs`."""
    return self.inference_step(inputs)

  @tf.function
  def serve_text(self, text: tf.Tensor) -> tf.Tensor:
    """Translates raw strings end to end.

    Args:
      text: A string tensor of sentences to translate.

    Returns:
      The detokenized "outputs" entry produced by `serve`.
    """
    # Tokenization yields a ragged tensor; densify it, padding with id 0.
    tokenized = self._sp_tokenizer.tokenize(text).to_tensor(0)
    return self._sp_tokenizer.detokenize(
        self.serve({"inputs": tokenized})["outputs"])

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Builds the concrete functions to export.

    Args:
      function_keys: Maps a function key (only "serve_text" is supported) to
        the signature key under which it is exported.

    Returns:
      A dict mapping each signature key to a concrete function.

    Raises:
      ValueError: If a function key is not supported by this module.
    """
    signatures = {}
    # Must be a real tuple: with a bare string, `in` is a substring test and
    # keys such as "serve" or "text" would be accepted and silently dropped.
    valid_keys = ("serve_text",)
    for func_key, signature_key in function_keys.items():
      if func_key not in valid_keys:
        raise ValueError(
            "Invalid function key for the module: %s with key %s. "
            "Valid keys are: %s" % (self.__class__, func_key, valid_keys))
      if func_key == "serve_text":
        signatures[signature_key] = self.serve_text.get_concrete_function(
            tf.TensorSpec(
                shape=[self._batch_size], dtype=tf.string, name="text"))
    return signatures
official/nlp/serving/serving_modules_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -15,8 +15,12 @@
"""Tests for nlp.serving.serving_modules."""
import
os
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
sentencepiece
import
SentencePieceTrainer
from
official.core
import
export_base
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.serving
import
serving_modules
...
...
@@ -24,6 +28,7 @@ from official.nlp.tasks import masked_lm
from
official.nlp.tasks
import
question_answering
from
official.nlp.tasks
import
sentence_prediction
from
official.nlp.tasks
import
tagging
from
official.nlp.tasks
import
translation
def
_create_fake_serialized_examples
(
features_dict
):
...
...
@@ -59,6 +64,33 @@ def _create_fake_vocab_file(vocab_file_path):
outfile
.
write
(
"
\n
"
.
join
(
tokens
))
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
  """Trains a BPE SentencePiece model on the text file at `input_path`.

  Args:
    input_path: Path of the training corpus passed as `--input`.
    vocab_size: Value passed as `--vocab_size`.
    model_path: Value passed as `--model_prefix`.
    eos_id: Id assigned to the EOS token; pad is fixed to 0, unk to 2 and
      BOS is disabled (-1).
  """
  trainer_flags = (
      f"--input={input_path}",
      f"--vocab_size={vocab_size}",
      "--character_coverage=0.995",
      f"--model_prefix={model_path}",
      "--model_type=bpe",
      "--bos_id=-1",
      "--pad_id=0",
      f"--eos_id={eos_id}",
      "--unk_id=2",
  )
  # The trainer takes all flags as a single space-separated string.
  SentencePieceTrainer.Train(" ".join(trainer_flags))
def _generate_line_file(filepath, lines):
  """Writes every item of `lines` to `filepath`, each followed by a newline."""
  # Assemble the whole payload first so the file is written in one call.
  payload = "".join("{}\n".format(entry) for entry in lines)
  with tf.io.gfile.GFile(filepath, "w") as out_file:
    out_file.write(payload)
def _make_sentencepeice(output_dir):
  """Trains a tiny SentencePiece model under `output_dir`.

  A six-line toy corpus is written to "inputs.txt" and a model with a
  vocabulary of 11 pieces is trained with the prefix "sp".

  Args:
    output_dir: Directory receiving the corpus file and the trained model.

  Returns:
    Path of the trained model file (the prefix plus ".model").
  """
  source_sentences = ["abc ede fg", "bbcd ef a g", "de f a a g"]
  target_sentences = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]

  corpus_path = os.path.join(output_dir, "inputs.txt")
  _generate_line_file(corpus_path, source_sentences + target_sentences)

  model_prefix = os.path.join(output_dir, "sp")
  _train_sentencepiece(corpus_path, 11, model_prefix)
  return "{}.model".format(model_prefix)
class
ServingModulesTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
parameterized
.
parameters
(
...
...
@@ -312,6 +344,48 @@ class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
with
self
.
assertRaises
(
ValueError
):
_
=
export_module
.
get_inference_signatures
({
"foo"
:
None
})
@parameterized.parameters((False, None), (True, 2))
def test_translation(self, padded_decode, batch_size):
  """Checks the Translation module in memory and after a SavedModel export."""
  sp_path = _make_sentencepeice(self.get_temp_dir())

  # Build a small encoder/decoder translation model.
  encdecoder = translation.EncDecoder(
      num_attention_heads=4, intermediate_size=256)
  model_config = translation.ModelConfig(
      encoder=encdecoder,
      decoder=encdecoder,
      embedding_width=256,
      padded_decode=padded_decode,
      decode_max_length=100)
  config = translation.TranslationConfig(
      model=model_config,
      sentencepiece_model_path=sp_path,
  )
  model = translation.TranslationTask(config).build_model()

  # Wrap the model in the export module and call the signature directly.
  module_params = serving_modules.Translation.Params(
      sentencepiece_model_path=sp_path, batch_size=batch_size)
  export_module = serving_modules.Translation(
      params=module_params, model=model)
  functions = export_module.get_inference_signatures(
      {"serve_text": "serving_default"})
  serve_fn = functions["serving_default"]
  outputs = serve_fn(tf.constant(["abcd", "ef gh"]))
  self.assertEqual(outputs.shape, (2,))
  self.assertEqual(outputs.dtype, tf.string)

  # Round-trip through a checkpoint and a SavedModel export; each
  # parameterization gets its own sub-directory.
  work_dir = os.path.join(
      self.get_temp_dir(), "padded_decode", str(padded_decode))
  export_base_dir = os.path.join(work_dir, "export")
  ckpt_dir = os.path.join(work_dir, "ckpt")
  ckpt_path = tf.train.Checkpoint(model=model).save(ckpt_dir)
  export_dir = export_base.export(
      export_module,
      {"serve_text": "serving_default"},
      export_base_dir,
      ckpt_path)

  reloaded = tf.saved_model.load(export_dir)
  infer = reloaded.signatures["serving_default"]
  out = infer(text=tf.constant(["abcd", "ef gh"]))
  self.assertLen(out["output_0"], 2)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/nlp/tasks/__init__.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/dual_encoder.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -187,9 +187,13 @@ class DualEncoderTask(base_task.Task):
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
logging
.
info
(
'Trying to load pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
if
ckpt_dir_or_file
and
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
logging
.
info
(
'No checkpoint file found from %s. Will not load.'
,
ckpt_dir_or_file
)
return
pretrain2finetune_mapping
=
{
...
...
official/nlp/tasks/dual_encoder_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -19,7 +19,7 @@ import os
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.
nlp
.bert
import
configs
from
official.
legacy
.bert
import
configs
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
dual_encoder_dataloader
...
...
official/nlp/tasks/electra_task.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/electra_task_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/masked_lm.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/masked_lm_determinism_test.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that masked LM models are deterministic when determinism is enabled."""
import
tensorflow
as
tf
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
pretrain_dataloader
from
official.nlp.tasks
import
masked_lm
class MLMTaskTest(tf.test.TestCase):
  """Checks that two identically seeded masked-LM runs give equal results."""

  def _build_dataset(self, params, vocab_size):
    """Returns an endless dataset of synthetic pretraining examples.

    Args:
      params: Data config; `seq_length` and `max_predictions_per_seq` are read.
      vocab_size: Exclusive upper bound for the random word ids.

    Returns:
      A repeating `tf.data.Dataset` whose elements are feature dicts with a
      leading dimension of 1.
    """

    def dummy_data(_):
      dummy_ids = tf.random.uniform((1, params.seq_length),
                                    maxval=vocab_size,
                                    dtype=tf.int32)
      dummy_mask = tf.ones((1, params.seq_length), dtype=tf.int32)
      dummy_type_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
      # Positions, ids and weights of the masked tokens are all zeros.
      dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
      return dict(
          input_word_ids=dummy_ids,
          input_mask=dummy_mask,
          input_type_ids=dummy_type_ids,
          masked_lm_positions=dummy_lm,
          masked_lm_ids=dummy_lm,
          masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
          next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
    dataset = dataset.map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

  def _build_and_run_model(self, config, num_steps=5):
    """Builds a fresh task/model, trains `num_steps` steps and validates once.

    Args:
      config: The `MaskedLMConfig` describing the task.
      num_steps: Number of training steps to run.

    Returns:
      A tuple `(train_logs, validation_logs, model_weights)`.
    """
    task = masked_lm.MaskedLMTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = self._build_dataset(config.train_data,
                                  config.model.encoder.get().vocab_size)
    iterator = iter(dataset)
    # `learning_rate` is the supported keyword; the legacy `lr` alias is
    # deprecated, and the other tests in this change use `learning_rate` too.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

    # Run training
    for _ in range(num_steps):
      logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    for metric in metrics:
      logs[metric.name] = metric.result()

    # Run validation
    validation_logs = task.validation_step(next(iterator), model,
                                           metrics=metrics)
    for metric in metrics:
      validation_logs[metric.name] = metric.result()

    return logs, validation_logs, model.weights

  def test_task_determinism(self):
    """Two runs started from the same seed must match exactly."""
    config = masked_lm.MaskedLMConfig(
        init_checkpoint=self.get_temp_dir(),
        scale_loss=True,
        model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))

    # Reset the seed before each run so both start from the same random state.
    tf.keras.utils.set_random_seed(1)
    logs1, validation_logs1, weights1 = self._build_and_run_model(config)
    tf.keras.utils.set_random_seed(1)
    logs2, validation_logs2, weights2 = self._build_and_run_model(config)

    self.assertEqual(logs1["loss"], logs2["loss"])
    self.assertEqual(validation_logs1["loss"], validation_logs2["loss"])
    # zip() stops at the shorter sequence, so check the lengths explicitly.
    self.assertEqual(len(weights1), len(weights2))
    for weight1, weight2 in zip(weights1, weights2):
      self.assertAllEqual(weight1, weight2)
if __name__ == "__main__":
  # Op determinism is switched on before any test runs; the test in this
  # module compares two seeded runs for exact equality.
  tf.config.experimental.enable_op_determinism()
  tf.test.main()
official/nlp/tasks/masked_lm_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/question_answering.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,13 +13,13 @@
# limitations under the License.
"""Question answering task."""
import
dataclasses
import
functools
import
json
import
os
from
typing
import
List
,
Optional
from
absl
import
logging
import
dataclasses
import
orbit
import
tensorflow
as
tf
...
...
@@ -27,15 +27,15 @@ from official.core import base_task
from
official.core
import
config_definitions
as
cfg
from
official.core
import
task_factory
from
official.modeling.hyperparams
import
base_config
from
official.nlp.bert
import
squad_evaluate_v1_1
from
official.nlp.bert
import
squad_evaluate_v2_0
from
official.nlp.bert
import
tokenization
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
data_loader_factory
from
official.nlp.data
import
squad_lib
as
squad_lib_wp
from
official.nlp.data
import
squad_lib_sp
from
official.nlp.modeling
import
models
from
official.nlp.tasks
import
utils
from
official.nlp.tools
import
squad_evaluate_v1_1
from
official.nlp.tools
import
squad_evaluate_v2_0
from
official.nlp.tools
import
tokenization
@
dataclasses
.
dataclass
...
...
official/nlp/tasks/question_answering_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/sentence_prediction.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -34,7 +34,7 @@ from official.nlp.modeling import models
from
official.nlp.tasks
import
utils
METRIC_TYPES
=
frozenset
(
[
'accuracy'
,
'matthews_corrcoef'
,
'pearson_spearman_corr'
])
[
'accuracy'
,
'f1'
,
'matthews_corrcoef'
,
'pearson_spearman_corr'
])
@
dataclasses
.
dataclass
...
...
@@ -165,14 +165,17 @@ class SentencePredictionTask(base_task.Task):
compiled_metrics
.
update_state
(
labels
[
self
.
label_field
],
model_outputs
)
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
if
self
.
metric_type
==
'accuracy'
:
return
super
(
SentencePredictionTask
,
self
).
validation_step
(
inputs
,
model
,
metrics
)
features
,
labels
=
inputs
,
inputs
outputs
=
self
.
inference_step
(
features
,
model
)
loss
=
self
.
build_losses
(
labels
=
labels
,
model_outputs
=
outputs
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
self
.
process_metrics
(
metrics
,
labels
,
outputs
)
if
model
.
compiled_metrics
:
self
.
process_compiled_metrics
(
model
.
compiled_metrics
,
labels
,
outputs
)
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
metrics
or
[]})
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
model
.
metrics
})
if
self
.
metric_type
==
'matthews_corrcoef'
:
logs
.
update
({
'sentence_prediction'
:
# Ensure one prediction along batch dimension.
...
...
@@ -180,7 +183,7 @@ class SentencePredictionTask(base_task.Task):
'labels'
:
labels
[
self
.
label_field
],
})
if
self
.
metric_type
==
'pearson_spearman_corr'
:
else
:
logs
.
update
({
'sentence_prediction'
:
outputs
,
'labels'
:
labels
[
self
.
label_field
],
...
...
@@ -202,18 +205,20 @@ class SentencePredictionTask(base_task.Task):
def
reduce_aggregated_logs
(
self
,
aggregated_logs
,
global_step
=
None
):
if
self
.
metric_type
==
'accuracy'
:
return
None
preds
=
np
.
concatenate
(
aggregated_logs
[
'sentence_prediction'
],
axis
=
0
)
labels
=
np
.
concatenate
(
aggregated_logs
[
'labels'
],
axis
=
0
)
if
self
.
metric_type
==
'f1'
:
preds
=
np
.
argmax
(
preds
,
axis
=
1
)
return
{
self
.
metric_type
:
sklearn_metrics
.
f1_score
(
labels
,
preds
)}
elif
self
.
metric_type
==
'matthews_corrcoef'
:
preds
=
np
.
concatenate
(
aggregated_logs
[
'sentence_prediction'
],
axis
=
0
)
preds
=
np
.
reshape
(
preds
,
-
1
)
labels
=
np
.
concatenate
(
aggregated_logs
[
'labels'
],
axis
=
0
)
labels
=
np
.
reshape
(
labels
,
-
1
)
return
{
self
.
metric_type
:
sklearn_metrics
.
matthews_corrcoef
(
preds
,
labels
)
}
elif
self
.
metric_type
==
'pearson_spearman_corr'
:
preds
=
np
.
concatenate
(
aggregated_logs
[
'sentence_prediction'
],
axis
=
0
)
preds
=
np
.
reshape
(
preds
,
-
1
)
labels
=
np
.
concatenate
(
aggregated_logs
[
'labels'
],
axis
=
0
)
labels
=
np
.
reshape
(
labels
,
-
1
)
pearson_corr
=
stats
.
pearsonr
(
preds
,
labels
)[
0
]
spearman_corr
=
stats
.
spearmanr
(
preds
,
labels
)[
0
]
...
...
@@ -223,10 +228,14 @@ class SentencePredictionTask(base_task.Task):
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
logging
.
info
(
'Trying to load pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
if
ckpt_dir_or_file
and
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
logging
.
info
(
'No checkpoint file found from %s. Will not load.'
,
ckpt_dir_or_file
)
return
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
pretrain2finetune_mapping
=
{
'encoder'
:
model
.
checkpoint_items
[
'encoder'
],
...
...
official/nlp/tasks/sentence_prediction_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -32,10 +32,12 @@ def _create_fake_dataset(output_path, seq_length, num_classes, num_examples):
writer
=
tf
.
io
.
TFRecordWriter
(
output_path
)
def
create_int_feature
(
values
):
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
list
(
values
)))
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
np
.
ravel
(
values
)))
def
create_float_feature
(
values
):
return
tf
.
train
.
Feature
(
float_list
=
tf
.
train
.
FloatList
(
value
=
list
(
values
)))
return
tf
.
train
.
Feature
(
float_list
=
tf
.
train
.
FloatList
(
value
=
np
.
ravel
(
values
)))
for
i
in
range
(
num_examples
):
features
=
{}
...
...
@@ -81,7 +83,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
functools
.
partial
(
task
.
build_inputs
,
config
.
train_data
))
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
l
r
=
0.1
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
l
earning_rate
=
0.1
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
model
.
save
(
os
.
path
.
join
(
self
.
get_temp_dir
(),
"saved_model"
))
return
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
...
...
@@ -118,7 +120,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
dataset
=
task
.
build_inputs
(
config
.
train_data
)
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
l
r
=
0.1
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
l
earning_rate
=
0.1
)
task
.
initialize
(
model
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
...
...
@@ -149,7 +151,7 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
dataset
=
task
.
build_inputs
(
config
.
train_data
)
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
l
r
=
0.1
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
l
earning_rate
=
0.1
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
logs
=
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
...
...
@@ -160,7 +162,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertLess
(
loss
,
1.0
)
@
parameterized
.
parameters
((
"matthews_corrcoef"
,
2
),
(
"pearson_spearman_corr"
,
1
))
(
"pearson_spearman_corr"
,
1
),
(
"f1"
,
2
))
def
test_np_metrics
(
self
,
metric_type
,
num_classes
):
config
=
sentence_prediction
.
SentencePredictionConfig
(
metric_type
=
metric_type
,
...
...
official/nlp/tasks/tagging.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/tagging_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/translation.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/nlp/tasks/translation_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
Prev
1
…
22
23
24
25
26
27
28
29
30
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment