Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9b4be3ab
Commit
9b4be3ab
authored
Jul 08, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
Jul 08, 2020
Browse files
Add a predict method in tagging task.
PiperOrigin-RevId: 320145633
parent
2a38d9a4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
370 deletions
+7
-370
official/nlp/data/tagging_data_loader.py
official/nlp/data/tagging_data_loader.py
+7
-0
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+0
-225
official/nlp/tasks/tagging_test.py
official/nlp/tasks/tagging_test.py
+0
-145
No files found.
official/nlp/data/tagging_data_loader.py
View file @
9b4be3ab
...
...
@@ -28,6 +28,7 @@ class TaggingDataConfig(cfg.DataConfig):
"""Data config for tagging (tasks/tagging)."""
is_training
:
bool
=
True
seq_length
:
int
=
128
include_sentence_id
:
bool
=
False
@
data_loader_factory
.
register_data_loader_cls
(
TaggingDataConfig
)
...
...
@@ -37,6 +38,7 @@ class TaggingDataLoader:
def
__init__
(
self
,
params
:
TaggingDataConfig
):
self
.
_params
=
params
self
.
_seq_length
=
params
.
seq_length
self
.
_include_sentence_id
=
params
.
include_sentence_id
def
_decode
(
self
,
record
:
tf
.
Tensor
):
"""Decodes a serialized tf.Example."""
...
...
@@ -46,6 +48,9 @@ class TaggingDataLoader:
'segment_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'label_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
}
if
self
.
_include_sentence_id
:
name_to_features
[
'sentence_id'
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
)
example
=
tf
.
io
.
parse_single_example
(
record
,
name_to_features
)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
...
...
@@ -65,6 +70,8 @@ class TaggingDataLoader:
'input_mask'
:
record
[
'input_mask'
],
'input_type_ids'
:
record
[
'segment_ids'
]
}
if
self
.
_include_sentence_id
:
x
[
'sentence_id'
]
=
record
[
'sentence_id'
]
y
=
record
[
'label_ids'
]
return
(
x
,
y
)
...
...
official/nlp/tasks/tagging.py
deleted
100644 → 0
View file @
2a38d9a4
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tagging (e.g., NER/POS) task."""
import
logging
from
typing
import
List
,
Optional
import
dataclasses
from
seqeval
import
metrics
as
seqeval_metrics
import
tensorflow
as
tf
import
tensorflow_hub
as
hub
from
official.core
import
base_task
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
data_loader_factory
from
official.nlp.modeling
import
models
from
official.nlp.tasks
import
utils
@
dataclasses
.
dataclass
class
ModelConfig
(
base_config
.
Config
):
"""A base span labeler configuration."""
encoder
:
encoders
.
TransformerEncoderConfig
=
(
encoders
.
TransformerEncoderConfig
())
head_dropout
:
float
=
0.1
head_initializer_range
:
float
=
0.02
@
dataclasses
.
dataclass
class
TaggingConfig
(
cfg
.
TaskConfig
):
"""The model config."""
# At most one of `init_checkpoint` and `hub_module_url` can be specified.
init_checkpoint
:
str
=
''
hub_module_url
:
str
=
''
model
:
ModelConfig
=
ModelConfig
()
# The real class names, the order of which should match real label id.
# Note that a word may be tokenized into multiple word_pieces tokens, and
# we asssume the real label id (non-negative) is assigned to the first token
# of the word, and a negative label id is assigned to the remaining tokens.
# The negative label id will not contribute to loss and metrics.
class_names
:
Optional
[
List
[
str
]]
=
None
train_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
def
_masked_labels_and_weights
(
y_true
):
"""Masks negative values from token level labels.
Args:
y_true: Token labels, typically shape (batch_size, seq_len), where tokens
with negative labels should be ignored during loss/accuracy calculation.
Returns:
(masked_y_true, masked_weights) where `masked_y_true` is the input
with each negative label replaced with zero and `masked_weights` is 0.0
where negative labels were replaced and 1.0 for original labels.
"""
# Ignore the classes of tokens with negative values.
mask
=
tf
.
greater_equal
(
y_true
,
0
)
# Replace negative labels, which are out of bounds for some loss functions,
# with zero.
masked_y_true
=
tf
.
where
(
mask
,
y_true
,
0
)
return
masked_y_true
,
tf
.
cast
(
mask
,
tf
.
float32
)
@
base_task
.
register_task_cls
(
TaggingConfig
)
class
TaggingTask
(
base_task
.
Task
):
"""Task object for tagging (e.g., NER or POS)."""
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
,
logging_dir
=
None
):
super
(
TaggingTask
,
self
).
__init__
(
params
,
logging_dir
)
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
if
not
params
.
class_names
:
raise
ValueError
(
'TaggingConfig.class_names cannot be empty.'
)
if
params
.
hub_module_url
:
self
.
_hub_module
=
hub
.
load
(
params
.
hub_module_url
)
else
:
self
.
_hub_module
=
None
def
build_model
(
self
):
if
self
.
_hub_module
:
encoder_network
=
utils
.
get_encoder_from_hub
(
self
.
_hub_module
)
else
:
encoder_network
=
encoders
.
instantiate_encoder_from_cfg
(
self
.
task_config
.
model
.
encoder
)
return
models
.
BertTokenClassifier
(
network
=
encoder_network
,
num_classes
=
len
(
self
.
task_config
.
class_names
),
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
self
.
task_config
.
model
.
head_initializer_range
),
dropout_rate
=
self
.
task_config
.
model
.
head_dropout
,
output
=
'logits'
)
def
build_losses
(
self
,
labels
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
model_outputs
=
tf
.
cast
(
model_outputs
,
tf
.
float32
)
masked_labels
,
masked_weights
=
_masked_labels_and_weights
(
labels
)
loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
masked_labels
,
model_outputs
,
from_logits
=
True
)
numerator_loss
=
tf
.
reduce_sum
(
loss
*
masked_weights
)
denominator_loss
=
tf
.
reduce_sum
(
masked_weights
)
loss
=
tf
.
math
.
divide_no_nan
(
numerator_loss
,
denominator_loss
)
return
loss
def
build_inputs
(
self
,
params
,
input_context
=
None
):
"""Returns tf.data.Dataset for sentence_prediction task."""
if
params
.
input_path
==
'dummy'
:
def
dummy_data
(
_
):
dummy_ids
=
tf
.
zeros
((
1
,
params
.
seq_length
),
dtype
=
tf
.
int32
)
x
=
dict
(
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
)
# Include some label_id as -1, which will be ignored in loss/metrics.
y
=
tf
.
random
.
uniform
(
shape
=
(
1
,
params
.
seq_length
),
minval
=-
1
,
maxval
=
len
(
self
.
task_config
.
class_names
),
dtype
=
tf
.
dtypes
.
int32
)
return
(
x
,
y
)
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
map
(
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
return
data_loader_factory
.
get_data_loader
(
params
).
load
(
input_context
)
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features
,
labels
=
inputs
outputs
=
self
.
inference_step
(
features
,
model
)
loss
=
self
.
build_losses
(
labels
=
labels
,
model_outputs
=
outputs
)
# Negative label ids are padding labels which should be ignored.
real_label_index
=
tf
.
where
(
tf
.
greater_equal
(
labels
,
0
))
predict_ids
=
tf
.
math
.
argmax
(
outputs
,
axis
=-
1
)
predict_ids
=
tf
.
gather_nd
(
predict_ids
,
real_label_index
)
label_ids
=
tf
.
gather_nd
(
labels
,
real_label_index
)
return
{
self
.
loss
:
loss
,
'predict_ids'
:
predict_ids
,
'label_ids'
:
label_ids
,
}
def
aggregate_logs
(
self
,
state
=
None
,
step_outputs
=
None
):
"""Aggregates over logs returned from a validation step."""
if
state
is
None
:
state
=
{
'predict_class'
:
[],
'label_class'
:
[]}
def
id_to_class_name
(
batched_ids
):
class_names
=
[]
for
per_example_ids
in
batched_ids
:
class_names
.
append
([])
for
per_token_id
in
per_example_ids
.
numpy
().
tolist
():
class_names
[
-
1
].
append
(
self
.
task_config
.
class_names
[
per_token_id
])
return
class_names
# Convert id to class names, because `seqeval_metrics` relies on the class
# name to decide IOB tags.
state
[
'predict_class'
].
extend
(
id_to_class_name
(
step_outputs
[
'predict_ids'
]))
state
[
'label_class'
].
extend
(
id_to_class_name
(
step_outputs
[
'label_ids'
]))
return
state
def
reduce_aggregated_logs
(
self
,
aggregated_logs
):
"""Reduces aggregated logs over validation steps."""
label_class
=
aggregated_logs
[
'label_class'
]
predict_class
=
aggregated_logs
[
'predict_class'
]
return
{
'f1'
:
seqeval_metrics
.
f1_score
(
label_class
,
predict_class
),
'precision'
:
seqeval_metrics
.
precision_score
(
label_class
,
predict_class
),
'recall'
:
seqeval_metrics
.
recall_score
(
label_class
,
predict_class
),
'accuracy'
:
seqeval_metrics
.
accuracy_score
(
label_class
,
predict_class
),
}
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
return
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
official/nlp/tasks/tagging_test.py
deleted
100644 → 0
View file @
2a38d9a4
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.tagging."""
import
functools
import
os
import
tensorflow
as
tf
from
official.nlp.bert
import
configs
from
official.nlp.bert
import
export_tfhub
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
tagging_data_loader
from
official.nlp.tasks
import
tagging
class
TaggingTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
TaggingTest
,
self
).
setUp
()
self
.
_encoder_config
=
encoders
.
TransformerEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)
self
.
_train_data_config
=
tagging_data_loader
.
TaggingDataConfig
(
input_path
=
"dummy"
,
seq_length
=
128
,
global_batch_size
=
1
)
def
_run_task
(
self
,
config
):
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
strategy
=
tf
.
distribute
.
get_strategy
()
dataset
=
strategy
.
experimental_distribute_datasets_from_function
(
functools
.
partial
(
task
.
build_inputs
,
config
.
train_data
))
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
0.1
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
def
test_task
(
self
):
# Saves a checkpoint.
encoder
=
encoders
.
instantiate_encoder_from_cfg
(
self
.
_encoder_config
)
ckpt
=
tf
.
train
.
Checkpoint
(
encoder
=
encoder
)
saved_path
=
ckpt
.
save
(
self
.
get_temp_dir
())
config
=
tagging
.
TaggingConfig
(
init_checkpoint
=
saved_path
,
model
=
tagging
.
ModelConfig
(
encoder
=
self
.
_encoder_config
),
train_data
=
self
.
_train_data_config
,
class_names
=
[
"O"
,
"B-PER"
,
"I-PER"
])
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
dataset
=
task
.
build_inputs
(
config
.
train_data
)
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
0.1
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
task
.
initialize
(
model
)
def
test_task_with_fit
(
self
):
config
=
tagging
.
TaggingConfig
(
model
=
tagging
.
ModelConfig
(
encoder
=
self
.
_encoder_config
),
train_data
=
self
.
_train_data_config
,
class_names
=
[
"O"
,
"B-PER"
,
"I-PER"
])
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
model
=
task
.
compile_model
(
model
,
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
0.1
),
train_step
=
task
.
train_step
,
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
"accuracy"
)])
dataset
=
task
.
build_inputs
(
config
.
train_data
)
logs
=
model
.
fit
(
dataset
,
epochs
=
1
,
steps_per_epoch
=
2
)
self
.
assertIn
(
"loss"
,
logs
.
history
)
self
.
assertIn
(
"accuracy"
,
logs
.
history
)
def
_export_bert_tfhub
(
self
):
bert_config
=
configs
.
BertConfig
(
vocab_size
=
30522
,
hidden_size
=
16
,
intermediate_size
=
32
,
max_position_embeddings
=
128
,
num_attention_heads
=
2
,
num_hidden_layers
=
1
)
_
,
encoder
=
export_tfhub
.
create_bert_model
(
bert_config
)
model_checkpoint_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"checkpoint"
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
encoder
)
checkpoint
.
save
(
os
.
path
.
join
(
model_checkpoint_dir
,
"test"
))
model_checkpoint_path
=
tf
.
train
.
latest_checkpoint
(
model_checkpoint_dir
)
vocab_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"uncased_vocab.txt"
)
with
tf
.
io
.
gfile
.
GFile
(
vocab_file
,
"w"
)
as
f
:
f
.
write
(
"dummy content"
)
hub_destination
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"hub"
)
export_tfhub
.
export_bert_tfhub
(
bert_config
,
model_checkpoint_path
,
hub_destination
,
vocab_file
)
return
hub_destination
def
test_task_with_hub
(
self
):
hub_module_url
=
self
.
_export_bert_tfhub
()
config
=
tagging
.
TaggingConfig
(
hub_module_url
=
hub_module_url
,
class_names
=
[
"O"
,
"B-PER"
,
"I-PER"
],
train_data
=
self
.
_train_data_config
)
self
.
_run_task
(
config
)
def
test_seqeval_metrics
(
self
):
config
=
tagging
.
TaggingConfig
(
model
=
tagging
.
ModelConfig
(
encoder
=
self
.
_encoder_config
),
train_data
=
self
.
_train_data_config
,
class_names
=
[
"O"
,
"B-PER"
,
"I-PER"
])
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
dataset
=
task
.
build_inputs
(
config
.
train_data
)
iterator
=
iter
(
dataset
)
strategy
=
tf
.
distribute
.
get_strategy
()
distributed_outputs
=
strategy
.
run
(
functools
.
partial
(
task
.
validation_step
,
model
=
model
),
args
=
(
next
(
iterator
),))
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
distributed_outputs
)
aggregated
=
task
.
aggregate_logs
(
step_outputs
=
outputs
)
aggregated
=
task
.
aggregate_logs
(
state
=
aggregated
,
step_outputs
=
outputs
)
self
.
assertCountEqual
({
"f1"
,
"precision"
,
"recall"
,
"accuracy"
},
task
.
reduce_aggregated_logs
(
aggregated
).
keys
())
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment