Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
02242bc8
"vscode:/vscode.git/clone" did not exist on "66d883f2118cbaf925fcbfd130cbdc5d2387073d"
Commit
02242bc8
authored
Jun 25, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
Jun 25, 2020
Browse files
Internal change
PiperOrigin-RevId: 318387106
parent
7ebcbe20
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
107 additions
and
28 deletions
+107
-28
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+75
-23
official/nlp/tasks/tagging_test.py
official/nlp/tasks/tagging_test.py
+24
-3
official/pip_package/setup.py
official/pip_package/setup.py
+3
-0
official/requirements.txt
official/requirements.txt
+5
-2
No files found.
official/nlp/tasks/tagging.py
View file @
02242bc8
...
@@ -15,7 +15,12 @@
...
@@ -15,7 +15,12 @@
# ==============================================================================
# ==============================================================================
"""Tagging (e.g., NER/POS) task."""
"""Tagging (e.g., NER/POS) task."""
import
logging
import
logging
from
typing
import
List
,
Optional
import
dataclasses
import
dataclasses
from
seqeval
import
metrics
as
seqeval_metrics
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_hub
as
hub
import
tensorflow_hub
as
hub
...
@@ -36,12 +41,12 @@ class TaggingConfig(cfg.TaskConfig):
...
@@ -36,12 +41,12 @@ class TaggingConfig(cfg.TaskConfig):
model
:
encoders
.
TransformerEncoderConfig
=
(
model
:
encoders
.
TransformerEncoderConfig
=
(
encoders
.
TransformerEncoderConfig
())
encoders
.
TransformerEncoderConfig
())
# The
number of
real la
bels. Note that a word may be tokenized into
# The real
c
la
ss names, the order of which should match real label id.
#
multiple word_pieces tokens, and we assume the real label id (non-negative)
#
Note that a word may be tokenized into multiple word_pieces tokens, and
#
is
ass
igned to the first token of the word, and a negative label id is
#
we
ass
ume the real label id (non-negative) is assigned to the first token
#
assigned to the remaining tokens. The negative label id will not contribute
#
of the word, and a negative label id is assigned to the remaining tokens.
# to loss and metrics.
#
The negative label id will not contribute
to loss and metrics.
num_
class
es
:
int
=
0
class
_names
:
Optional
[
List
[
str
]]
=
None
train_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
train_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
...
@@ -75,8 +80,8 @@ class TaggingTask(base_task.Task):
...
@@ -75,8 +80,8 @@ class TaggingTask(base_task.Task):
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
'`init_checkpoint` can be specified.'
)
if
params
.
num_
class
es
==
0
:
if
not
params
.
class
_names
:
raise
ValueError
(
'TaggingConfig.
num_
classes cannot be
0
.'
)
raise
ValueError
(
'TaggingConfig.class
_nam
es cannot be
empty
.'
)
if
params
.
hub_module_url
:
if
params
.
hub_module_url
:
self
.
_hub_module
=
hub
.
load
(
params
.
hub_module_url
)
self
.
_hub_module
=
hub
.
load
(
params
.
hub_module_url
)
...
@@ -92,7 +97,7 @@ class TaggingTask(base_task.Task):
...
@@ -92,7 +97,7 @@ class TaggingTask(base_task.Task):
return
models
.
BertTokenClassifier
(
return
models
.
BertTokenClassifier
(
network
=
encoder_network
,
network
=
encoder_network
,
num_classes
=
self
.
task_config
.
num_
classes
,
num_classes
=
len
(
self
.
task_config
.
class
_nam
es
)
,
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
self
.
task_config
.
model
.
initializer_range
),
stddev
=
self
.
task_config
.
model
.
initializer_range
),
dropout_rate
=
self
.
task_config
.
model
.
dropout_rate
,
dropout_rate
=
self
.
task_config
.
model
.
dropout_rate
,
...
@@ -123,7 +128,7 @@ class TaggingTask(base_task.Task):
...
@@ -123,7 +128,7 @@ class TaggingTask(base_task.Task):
y
=
tf
.
random
.
uniform
(
y
=
tf
.
random
.
uniform
(
shape
=
(
1
,
params
.
seq_length
),
shape
=
(
1
,
params
.
seq_length
),
minval
=-
1
,
minval
=-
1
,
maxval
=
self
.
task_config
.
num_
classes
,
maxval
=
len
(
self
.
task_config
.
class
_nam
es
)
,
dtype
=
tf
.
dtypes
.
int32
)
dtype
=
tf
.
dtypes
.
int32
)
return
(
x
,
y
)
return
(
x
,
y
)
...
@@ -136,19 +141,66 @@ class TaggingTask(base_task.Task):
...
@@ -136,19 +141,66 @@ class TaggingTask(base_task.Task):
dataset
=
tagging_data_loader
.
TaggingDataLoader
(
params
).
load
(
input_context
)
dataset
=
tagging_data_loader
.
TaggingDataLoader
(
params
).
load
(
input_context
)
return
dataset
return
dataset
def
build_metrics
(
self
,
training
=
None
):
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
del
training
"""Validation step.
# TODO(chendouble): evaluate using seqeval's f1/precision/recall.
return
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
)]
Args:
inputs: a dictionary of input tensors.
def
process_metrics
(
self
,
metrics
,
labels
,
model_outputs
):
model: the keras.Model.
masked_labels
,
masked_weights
=
_masked_labels_and_weights
(
labels
)
metrics: a nested structure of metrics objects.
for
metric
in
metrics
:
metric
.
update_state
(
masked_labels
,
model_outputs
,
masked_weights
)
Returns:
A dictionary of logs.
def
process_compiled_metrics
(
self
,
compiled_metrics
,
labels
,
model_outputs
):
"""
masked_labels
,
masked_weights
=
_masked_labels_and_weights
(
labels
)
features
,
labels
=
inputs
compiled_metrics
.
update_state
(
masked_labels
,
model_outputs
,
masked_weights
)
outputs
=
self
.
inference_step
(
features
,
model
)
loss
=
self
.
build_losses
(
labels
=
labels
,
model_outputs
=
outputs
)
# Negative label ids are padding labels which should be ignored.
real_label_index
=
tf
.
where
(
tf
.
greater_equal
(
labels
,
0
))
predict_ids
=
tf
.
math
.
argmax
(
outputs
,
axis
=-
1
)
predict_ids
=
tf
.
gather_nd
(
predict_ids
,
real_label_index
)
label_ids
=
tf
.
gather_nd
(
labels
,
real_label_index
)
return
{
self
.
loss
:
loss
,
'predict_ids'
:
predict_ids
,
'label_ids'
:
label_ids
,
}
def
aggregate_logs
(
self
,
state
=
None
,
step_outputs
=
None
):
"""Aggregates over logs returned from a validation step."""
if
state
is
None
:
state
=
{
'predict_class'
:
[],
'label_class'
:
[]}
def
id_to_class_name
(
batched_ids
):
class_names
=
[]
for
per_example_ids
in
batched_ids
:
class_names
.
append
([])
for
per_token_id
in
per_example_ids
.
numpy
().
tolist
():
class_names
[
-
1
].
append
(
self
.
task_config
.
class_names
[
per_token_id
])
return
class_names
# Convert id to class names, because `seqeval_metrics` relies on the class
# name to decide IOB tags.
state
[
'predict_class'
].
extend
(
id_to_class_name
(
step_outputs
[
'predict_ids'
]))
state
[
'label_class'
].
extend
(
id_to_class_name
(
step_outputs
[
'label_ids'
]))
return
state
def
reduce_aggregated_logs
(
self
,
aggregated_logs
):
"""Reduces aggregated logs over validation steps."""
label_class
=
aggregated_logs
[
'label_class'
]
predict_class
=
aggregated_logs
[
'predict_class'
]
return
{
'f1'
:
seqeval_metrics
.
f1_score
(
label_class
,
predict_class
),
'precision'
:
seqeval_metrics
.
precision_score
(
label_class
,
predict_class
),
'recall'
:
seqeval_metrics
.
recall_score
(
label_class
,
predict_class
),
'accuracy'
:
seqeval_metrics
.
accuracy_score
(
label_class
,
predict_class
),
}
def
initialize
(
self
,
model
):
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
...
...
official/nlp/tasks/tagging_test.py
View file @
02242bc8
...
@@ -58,7 +58,7 @@ class TaggingTest(tf.test.TestCase):
...
@@ -58,7 +58,7 @@ class TaggingTest(tf.test.TestCase):
init_checkpoint
=
saved_path
,
init_checkpoint
=
saved_path
,
model
=
self
.
_encoder_config
,
model
=
self
.
_encoder_config
,
train_data
=
self
.
_train_data_config
,
train_data
=
self
.
_train_data_config
,
num_
class
es
=
3
)
class
_names
=
[
"O"
,
"B-PER"
,
"I-PER"
]
)
task
=
tagging
.
TaggingTask
(
config
)
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
metrics
=
task
.
build_metrics
()
...
@@ -74,7 +74,7 @@ class TaggingTest(tf.test.TestCase):
...
@@ -74,7 +74,7 @@ class TaggingTest(tf.test.TestCase):
config
=
tagging
.
TaggingConfig
(
config
=
tagging
.
TaggingConfig
(
model
=
self
.
_encoder_config
,
model
=
self
.
_encoder_config
,
train_data
=
self
.
_train_data_config
,
train_data
=
self
.
_train_data_config
,
num_
class
es
=
3
)
class
_names
=
[
"O"
,
"B-PER"
,
"I-PER"
]
)
task
=
tagging
.
TaggingTask
(
config
)
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
model
=
task
.
build_model
()
...
@@ -116,10 +116,31 @@ class TaggingTest(tf.test.TestCase):
...
@@ -116,10 +116,31 @@ class TaggingTest(tf.test.TestCase):
config
=
tagging
.
TaggingConfig
(
config
=
tagging
.
TaggingConfig
(
hub_module_url
=
hub_module_url
,
hub_module_url
=
hub_module_url
,
model
=
self
.
_encoder_config
,
model
=
self
.
_encoder_config
,
num_
class
es
=
4
,
class
_names
=
[
"O"
,
"B-PER"
,
"I-PER"
]
,
train_data
=
self
.
_train_data_config
)
train_data
=
self
.
_train_data_config
)
self
.
_run_task
(
config
)
self
.
_run_task
(
config
)
def
test_seqeval_metrics
(
self
):
config
=
tagging
.
TaggingConfig
(
model
=
self
.
_encoder_config
,
train_data
=
self
.
_train_data_config
,
class_names
=
[
"O"
,
"B-PER"
,
"I-PER"
])
task
=
tagging
.
TaggingTask
(
config
)
model
=
task
.
build_model
()
dataset
=
task
.
build_inputs
(
config
.
train_data
)
iterator
=
iter
(
dataset
)
strategy
=
tf
.
distribute
.
get_strategy
()
distributed_outputs
=
strategy
.
run
(
functools
.
partial
(
task
.
validation_step
,
model
=
model
),
args
=
(
next
(
iterator
),))
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
distributed_outputs
)
aggregated
=
task
.
aggregate_logs
(
step_outputs
=
outputs
)
aggregated
=
task
.
aggregate_logs
(
state
=
aggregated
,
step_outputs
=
outputs
)
self
.
assertCountEqual
({
"f1"
,
"precision"
,
"recall"
,
"accuracy"
},
task
.
reduce_aggregated_logs
(
aggregated
).
keys
())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/pip_package/setup.py
View file @
02242bc8
...
@@ -45,6 +45,9 @@ def _get_requirements():
...
@@ -45,6 +45,9 @@ def _get_requirements():
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../requirements.txt'
),
'r'
)
as
f
:
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../requirements.txt'
),
'r'
)
as
f
:
for
line
in
f
:
for
line
in
f
:
package_name
=
line
.
strip
()
package_name
=
line
.
strip
()
# Skip empty line or comments starting with "#".
if
not
package_name
or
package_name
[
0
]
==
'#'
:
continue
if
package_name
.
startswith
(
'-e '
):
if
package_name
.
startswith
(
'-e '
):
dependency_links_tmp
.
append
(
package_name
[
3
:].
strip
())
dependency_links_tmp
.
append
(
package_name
[
3
:].
strip
())
else
:
else
:
...
...
official/requirements.txt
View file @
02242bc8
...
@@ -16,10 +16,13 @@ dataclasses
...
@@ -16,10 +16,13 @@ dataclasses
gin-config
gin-config
tf_slim>=1.1.0
tf_slim>=1.1.0
typing
typing
sentencepiece
Cython
Cython
matplotlib
matplotlib
opencv-python-headless
pyyaml
pyyaml
# CV related dependencies
opencv-python-headless
Pillow
Pillow
-e git+https://github.com/cocodataset/cocoapi#egg=pycocotools&subdirectory=PythonAPI
-e git+https://github.com/cocodataset/cocoapi#egg=pycocotools&subdirectory=PythonAPI
# NLP related dependencies
seqeval
sentencepiece
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment