Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a565d720
"llm/vscode:/vscode.git/clone" did not exist on "9d91e5e5875e2b2f8605ef15a7da9a616cb05171"
Commit
a565d720
authored
Jul 29, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 323732686
parent
250701c6
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
46 deletions
+16
-46
official/core/base_task.py
official/core/base_task.py
+14
-1
official/modeling/hyperparams/config_definitions.py
official/modeling/hyperparams/config_definitions.py
+2
-0
official/nlp/tasks/masked_lm.py
official/nlp/tasks/masked_lm.py
+0
-16
official/nlp/tasks/question_answering.py
official/nlp/tasks/question_answering.py
+0
-14
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+0
-15
No files found.
official/core/base_task.py
View file @
a565d720
...
@@ -18,6 +18,7 @@ import abc
...
@@ -18,6 +18,7 @@ import abc
import
functools
import
functools
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
,
Callable
,
Optional
from
absl
import
logging
import
six
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -67,7 +68,19 @@ class Task(tf.Module):
...
@@ -67,7 +68,19 @@ class Task(tf.Module):
Args:
Args:
model: The keras.Model built or used by this task.
model: The keras.Model built or used by this task.
"""
"""
pass
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
logging
.
info
(
"Trying to load pretrained checkpoint from %s"
,
ckpt_dir_or_file
)
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
return
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
"Finished loading pretrained checkpoint from %s"
,
ckpt_dir_or_file
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
build_model
(
self
)
->
tf
.
keras
.
Model
:
def
build_model
(
self
)
->
tf
.
keras
.
Model
:
...
...
official/modeling/hyperparams/config_definitions.py
View file @
a565d720
...
@@ -179,6 +179,7 @@ class TrainerConfig(base_config.Config):
...
@@ -179,6 +179,7 @@ class TrainerConfig(base_config.Config):
max_to_keep: max checkpoints to keep.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely.
checkpoints, if set to None, continuous eval will wait indefinitely.
This is only used continuous_train_and_eval and continuous_eval modes.
train_steps: number of train steps.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
is used.
...
@@ -205,6 +206,7 @@ class TrainerConfig(base_config.Config):
...
@@ -205,6 +206,7 @@ class TrainerConfig(base_config.Config):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TaskConfig
(
base_config
.
Config
):
class
TaskConfig
(
base_config
.
Config
):
init_checkpoint
:
str
=
""
model
:
base_config
.
Config
=
None
model
:
base_config
.
Config
=
None
train_data
:
DataConfig
=
DataConfig
()
train_data
:
DataConfig
=
DataConfig
()
validation_data
:
DataConfig
=
DataConfig
()
validation_data
:
DataConfig
=
DataConfig
()
...
...
official/nlp/tasks/masked_lm.py
View file @
a565d720
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Masked language task."""
"""Masked language task."""
from
absl
import
logging
import
dataclasses
import
dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -27,7 +26,6 @@ from official.nlp.data import data_loader_factory
...
@@ -27,7 +26,6 @@ from official.nlp.data import data_loader_factory
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
MaskedLMConfig
(
cfg
.
TaskConfig
):
class
MaskedLMConfig
(
cfg
.
TaskConfig
):
"""The model config."""
"""The model config."""
init_checkpoint
:
str
=
''
model
:
bert
.
BertPretrainerConfig
=
bert
.
BertPretrainerConfig
(
cls_heads
=
[
model
:
bert
.
BertPretrainerConfig
=
bert
.
BertPretrainerConfig
(
cls_heads
=
[
bert
.
ClsHeadConfig
(
bert
.
ClsHeadConfig
(
inner_dim
=
768
,
num_classes
=
2
,
dropout_rate
=
0.1
,
name
=
'next_sentence'
)
inner_dim
=
768
,
num_classes
=
2
,
dropout_rate
=
0.1
,
name
=
'next_sentence'
)
...
@@ -174,17 +172,3 @@ class MaskedLMTask(base_task.Task):
...
@@ -174,17 +172,3 @@ class MaskedLMTask(base_task.Task):
aux_losses
=
model
.
losses
)
aux_losses
=
model
.
losses
)
self
.
process_metrics
(
metrics
,
inputs
,
outputs
)
self
.
process_metrics
(
metrics
,
inputs
,
outputs
)
return
{
self
.
loss
:
loss
}
return
{
self
.
loss
:
loss
}
def
initialize
(
self
,
model
:
tf
.
keras
.
Model
):
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
return
# Restoring all modules defined by the model, e.g. encoder, masked_lm and
# cls pooler. The best initialization may vary case by case.
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
official/nlp/tasks/question_answering.py
View file @
a565d720
...
@@ -290,17 +290,3 @@ class QuestionAnsweringTask(base_task.Task):
...
@@ -290,17 +290,3 @@ class QuestionAnsweringTask(base_task.Task):
eval_metrics
=
{
'exact_match'
:
eval_metrics
[
'exact_match'
],
eval_metrics
=
{
'exact_match'
:
eval_metrics
[
'exact_match'
],
'final_f1'
:
eval_metrics
[
'final_f1'
]}
'final_f1'
:
eval_metrics
[
'final_f1'
]}
return
eval_metrics
return
eval_metrics
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
return
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
official/nlp/tasks/tagging.py
View file @
a565d720
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Tagging (e.g., NER/POS) task."""
"""Tagging (e.g., NER/POS) task."""
import
logging
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
dataclasses
import
dataclasses
...
@@ -215,20 +214,6 @@ class TaggingTask(base_task.Task):
...
@@ -215,20 +214,6 @@ class TaggingTask(base_task.Task):
seqeval_metrics
.
accuracy_score
(
label_class
,
predict_class
),
seqeval_metrics
.
accuracy_score
(
label_class
,
predict_class
),
}
}
def
initialize
(
self
,
model
):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
return
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
def
predict
(
task
:
TaggingTask
,
params
:
cfg
.
DataConfig
,
def
predict
(
task
:
TaggingTask
,
params
:
cfg
.
DataConfig
,
model
:
tf
.
keras
.
Model
)
->
Tuple
[
List
[
List
[
int
]],
List
[
int
]]:
model
:
tf
.
keras
.
Model
)
->
Tuple
[
List
[
List
[
int
]],
List
[
int
]]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment