Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
87e4768e
Commit
87e4768e
authored
Oct 13, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Oct 13, 2020
Browse files
Internal change
PiperOrigin-RevId: 336991213
parent
c6970b7f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
38 deletions
+31
-38
official/nlp/tasks/question_answering.py
official/nlp/tasks/question_answering.py
+11
-11
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+11
-12
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+7
-13
official/nlp/train_ctl_continuous_finetune.py
official/nlp/train_ctl_continuous_finetune.py
+2
-2
No files found.
official/nlp/tasks/question_answering.py
View file @
87e4768e
...
@@ -63,15 +63,8 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
...
@@ -63,15 +63,8 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
class
QuestionAnsweringTask
(
base_task
.
Task
):
class
QuestionAnsweringTask
(
base_task
.
Task
):
"""Task object for question answering."""
"""Task object for question answering."""
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
,
logging_dir
=
None
):
def
__init__
(
self
,
params
:
cfg
.
TaskConfig
,
logging_dir
=
None
,
name
=
None
):
super
(
QuestionAnsweringTask
,
self
).
__init__
(
params
,
logging_dir
)
super
().
__init__
(
params
,
logging_dir
,
name
=
name
)
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
if
params
.
hub_module_url
:
self
.
_hub_module
=
hub
.
load
(
params
.
hub_module_url
)
else
:
self
.
_hub_module
=
None
if
params
.
validation_data
.
tokenization
==
'WordPiece'
:
if
params
.
validation_data
.
tokenization
==
'WordPiece'
:
self
.
squad_lib
=
squad_lib_wp
self
.
squad_lib
=
squad_lib_wp
...
@@ -90,8 +83,15 @@ class QuestionAnsweringTask(base_task.Task):
...
@@ -90,8 +83,15 @@ class QuestionAnsweringTask(base_task.Task):
self
.
_tf_record_input_path
=
eval_input_path
self
.
_tf_record_input_path
=
eval_input_path
def
build_model
(
self
):
def
build_model
(
self
):
if
self
.
_hub_module
:
if
self
.
task_config
.
hub_module_url
and
self
.
task_config
.
init_checkpoint
:
encoder_network
=
utils
.
get_encoder_from_hub
(
self
.
_hub_module
)
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
if
self
.
task_config
.
hub_module_url
:
hub_module
=
hub
.
load
(
self
.
task_config
.
hub_module_url
)
else
:
hub_module
=
None
if
hub_module
:
encoder_network
=
utils
.
get_encoder_from_hub
(
hub_module
)
else
:
else
:
encoder_network
=
encoders
.
build_encoder
(
self
.
task_config
.
model
.
encoder
)
encoder_network
=
encoders
.
build_encoder
(
self
.
task_config
.
model
.
encoder
)
encoder_cfg
=
self
.
task_config
.
model
.
encoder
.
get
()
encoder_cfg
=
self
.
task_config
.
model
.
encoder
.
get
()
...
...
official/nlp/tasks/sentence_prediction.py
View file @
87e4768e
...
@@ -66,23 +66,22 @@ class SentencePredictionConfig(cfg.TaskConfig):
...
@@ -66,23 +66,22 @@ class SentencePredictionConfig(cfg.TaskConfig):
class
SentencePredictionTask
(
base_task
.
Task
):
class
SentencePredictionTask
(
base_task
.
Task
):
"""Task object for sentence_prediction."""
"""Task object for sentence_prediction."""
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
,
logging_dir
=
None
):
def
__init__
(
self
,
params
:
cfg
.
TaskConfig
,
logging_dir
=
None
,
name
=
None
):
super
(
SentencePredictionTask
,
self
).
__init__
(
params
,
logging_dir
)
super
().
__init__
(
params
,
logging_dir
,
name
=
name
)
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
if
params
.
hub_module_url
:
self
.
_hub_module
=
hub
.
load
(
params
.
hub_module_url
)
else
:
self
.
_hub_module
=
None
if
params
.
metric_type
not
in
METRIC_TYPES
:
if
params
.
metric_type
not
in
METRIC_TYPES
:
raise
ValueError
(
'Invalid metric_type: {}'
.
format
(
params
.
metric_type
))
raise
ValueError
(
'Invalid metric_type: {}'
.
format
(
params
.
metric_type
))
self
.
metric_type
=
params
.
metric_type
self
.
metric_type
=
params
.
metric_type
def
build_model
(
self
):
def
build_model
(
self
):
if
self
.
_hub_module
:
if
self
.
task_config
.
hub_module_url
and
self
.
task_config
.
init_checkpoint
:
encoder_network
=
utils
.
get_encoder_from_hub
(
self
.
_hub_module
)
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
if
self
.
task_config
.
hub_module_url
:
hub_module
=
hub
.
load
(
self
.
task_config
.
hub_module_url
)
else
:
hub_module
=
None
if
hub_module
:
encoder_network
=
utils
.
get_encoder_from_hub
(
hub_module
)
else
:
else
:
encoder_network
=
encoders
.
build_encoder
(
self
.
task_config
.
model
.
encoder
)
encoder_network
=
encoders
.
build_encoder
(
self
.
task_config
.
model
.
encoder
)
encoder_cfg
=
self
.
task_config
.
model
.
encoder
.
get
()
encoder_cfg
=
self
.
task_config
.
model
.
encoder
.
get
()
...
...
official/nlp/tasks/tagging.py
View file @
87e4768e
...
@@ -84,22 +84,16 @@ def _masked_labels_and_weights(y_true):
...
@@ -84,22 +84,16 @@ def _masked_labels_and_weights(y_true):
class
TaggingTask
(
base_task
.
Task
):
class
TaggingTask
(
base_task
.
Task
):
"""Task object for tagging (e.g., NER or POS)."""
"""Task object for tagging (e.g., NER or POS)."""
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
,
logging_dir
=
None
):
def
build_model
(
self
):
super
(
TaggingTask
,
self
).
__init__
(
params
,
logging_dir
)
if
self
.
task_config
.
hub_module_url
and
self
.
task_config
.
init_checkpoint
:
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
'`init_checkpoint` can be specified.'
)
if
not
params
.
class_names
:
if
self
.
task_config
.
hub_module_url
:
raise
ValueError
(
'TaggingConfig.class_names cannot be empty.'
)
hub_module
=
hub
.
load
(
self
.
task_config
.
hub_module_url
)
if
params
.
hub_module_url
:
self
.
_hub_module
=
hub
.
load
(
params
.
hub_module_url
)
else
:
else
:
self
.
_hub_module
=
None
hub_module
=
None
if
hub_module
:
def
build_model
(
self
):
encoder_network
=
utils
.
get_encoder_from_hub
(
hub_module
)
if
self
.
_hub_module
:
encoder_network
=
utils
.
get_encoder_from_hub
(
self
.
_hub_module
)
else
:
else
:
encoder_network
=
encoders
.
build_encoder
(
self
.
task_config
.
model
.
encoder
)
encoder_network
=
encoders
.
build_encoder
(
self
.
task_config
.
model
.
encoder
)
...
...
official/nlp/train_ctl_continuous_finetune.py
View file @
87e4768e
...
@@ -150,9 +150,9 @@ def run_continuous_finetune(
...
@@ -150,9 +150,9 @@ def run_continuous_finetune(
train_utils
.
write_json_summary
(
model_dir
,
global_step
,
eval_metrics
)
train_utils
.
write_json_summary
(
model_dir
,
global_step
,
eval_metrics
)
if
not
os
.
path
.
basename
(
model_dir
):
# if model_dir.endswith('/')
if
not
os
.
path
.
basename
(
model_dir
):
# if model_dir.endswith('/')
summary_grp
=
os
.
path
.
dirname
(
model_dir
)
+
'_'
+
task
.
__class__
.
__
name
__
summary_grp
=
os
.
path
.
dirname
(
model_dir
)
+
'_'
+
task
.
name
else
:
else
:
summary_grp
=
os
.
path
.
basename
(
model_dir
)
+
'_'
+
task
.
__class__
.
__
name
__
summary_grp
=
os
.
path
.
basename
(
model_dir
)
+
'_'
+
task
.
name
summaries
=
{}
summaries
=
{}
for
name
,
value
in
eval_metrics
.
items
():
for
name
,
value
in
eval_metrics
.
items
():
summaries
[
summary_grp
+
'/'
+
name
]
=
value
summaries
[
summary_grp
+
'/'
+
name
]
=
value
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment