Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
58d19c67
Commit
58d19c67
authored
Jun 30, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 319162455
parent
574455f5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
10 deletions
+29
-10
official/core/base_task.py
official/core/base_task.py
+13
-1
official/nlp/tasks/question_answering.py
official/nlp/tasks/question_answering.py
+12
-5
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+2
-2
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+2
-2
No files found.
official/core/base_task.py
View file @
58d19c67
...
...
@@ -37,13 +37,25 @@ class Task(tf.Module):
# Special keys in train/validate step returned logs.
loss
=
"loss"
def
__init__
(
self
,
params
:
cfg
.
TaskConfig
):
def
__init__
(
self
,
params
:
cfg
.
TaskConfig
,
logging_dir
:
str
=
None
):
"""Task initialization.
Args:
params: cfg.TaskConfig instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
"""
self
.
_task_config
=
params
self
.
_logging_dir
=
logging_dir
@
property
def
task_config
(
self
)
->
cfg
.
TaskConfig
:
return
self
.
_task_config
@
property
def
logging_dir
(
self
)
->
str
:
return
self
.
_logging_dir
def
initialize
(
self
,
model
:
tf
.
keras
.
Model
):
"""A callback function used as CheckpointManager's init_fn.
...
...
official/nlp/tasks/question_answering.py
View file @
58d19c67
...
...
@@ -54,8 +54,8 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
class
QuestionAnsweringTask
(
base_task
.
Task
):
"""Task object for question answering."""
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
):
super
(
QuestionAnsweringTask
,
self
).
__init__
(
params
)
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
,
logging_dir
=
None
):
super
(
QuestionAnsweringTask
,
self
).
__init__
(
params
,
logging_dir
)
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
...
...
@@ -72,6 +72,10 @@ class QuestionAnsweringTask(base_task.Task):
raise
ValueError
(
'Unsupported tokenization method: {}'
.
format
(
params
.
validation_data
.
tokenization
))
if
params
.
validation_data
.
input_path
:
self
.
_tf_record_input_path
,
self
.
_eval_examples
,
self
.
_eval_features
=
(
self
.
_preprocess_eval_data
(
params
.
validation_data
))
def
build_model
(
self
):
if
self
.
_hub_module
:
encoder_network
=
utils
.
get_encoder_from_hub
(
self
.
_hub_module
)
...
...
@@ -107,7 +111,11 @@ class QuestionAnsweringTask(base_task.Task):
is_training
=
False
,
version_2_with_negative
=
params
.
version_2_with_negative
)
temp_file_path
=
params
.
input_preprocessed_data_path
or
'/tmp'
temp_file_path
=
params
.
input_preprocessed_data_path
or
self
.
logging_dir
if
not
temp_file_path
:
raise
ValueError
(
'You must specify a temporary directory, either in '
'params.input_preprocessed_data_path or logging_dir to '
'store intermediate evaluation TFRecord data.'
)
eval_writer
=
self
.
squad_lib
.
FeatureWriter
(
filename
=
os
.
path
.
join
(
temp_file_path
,
'eval.tf_record'
),
is_training
=
False
)
...
...
@@ -168,8 +176,7 @@ class QuestionAnsweringTask(base_task.Task):
if
params
.
is_training
:
input_path
=
params
.
input_path
else
:
input_path
,
self
.
_eval_examples
,
self
.
_eval_features
=
(
self
.
_preprocess_eval_data
(
params
))
input_path
=
self
.
_tf_record_input_path
batch_size
=
input_context
.
get_per_replica_batch_size
(
params
.
global_batch_size
)
if
input_context
else
params
.
global_batch_size
...
...
official/nlp/tasks/sentence_prediction.py
View file @
58d19c67
...
...
@@ -55,8 +55,8 @@ class SentencePredictionConfig(cfg.TaskConfig):
class
SentencePredictionTask
(
base_task
.
Task
):
"""Task object for sentence_prediction."""
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
):
super
(
SentencePredictionTask
,
self
).
__init__
(
params
)
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
,
logging_dir
=
None
):
super
(
SentencePredictionTask
,
self
).
__init__
(
params
,
logging_dir
)
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
...
...
official/nlp/tasks/tagging.py
View file @
58d19c67
...
...
@@ -75,8 +75,8 @@ def _masked_labels_and_weights(y_true):
class
TaggingTask
(
base_task
.
Task
):
"""Task object for tagging (e.g., NER or POS)."""
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
):
super
(
TaggingTask
,
self
).
__init__
(
params
)
def
__init__
(
self
,
params
=
cfg
.
TaskConfig
,
logging_dir
=
None
):
super
(
TaggingTask
,
self
).
__init__
(
params
,
logging_dir
)
if
params
.
hub_module_url
and
params
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment