Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
eb9e2bfe
Commit
eb9e2bfe
authored
Oct 13, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Oct 13, 2020
Browse files
Internal change
PiperOrigin-RevId: 336960641
parent
08bfa83d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
5 deletions
+7
-5
official/core/base_task.py
official/core/base_task.py
+5
-3
official/utils/testing/mock_task.py
official/utils/testing/mock_task.py
+2
-2
No files found.
official/core/base_task.py
View file @
eb9e2bfe
...
@@ -33,15 +33,17 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
...
@@ -33,15 +33,17 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
# Special keys in train/validate step returned logs.
# Special keys in train/validate step returned logs.
loss
=
"loss"
loss
=
"loss"
def
__init__
(
self
,
params
,
logging_dir
:
str
=
None
):
def
__init__
(
self
,
params
,
logging_dir
:
str
=
None
,
name
:
str
=
None
):
"""Task initialization.
"""Task initialization.
Args:
Args:
params: the task configuration instance, which can be any of
params: the task configuration instance, which can be any of
dataclass,
dataclass,
ConfigDict, namedtuple, etc.
ConfigDict, namedtuple, etc.
logging_dir: a string pointing to where the model, summaries etc. will be
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
saved. You can also write additional stuff in this directory.
name: the task name.
"""
"""
super
().
__init__
(
name
=
name
)
self
.
_task_config
=
params
self
.
_task_config
=
params
self
.
_logging_dir
=
logging_dir
self
.
_logging_dir
=
logging_dir
...
...
official/utils/testing/mock_task.py
View file @
eb9e2bfe
...
@@ -46,8 +46,8 @@ class MockTaskConfig(cfg.TaskConfig):
...
@@ -46,8 +46,8 @@ class MockTaskConfig(cfg.TaskConfig):
class
MockTask
(
base_task
.
Task
):
class
MockTask
(
base_task
.
Task
):
"""Mock task object for testing."""
"""Mock task object for testing."""
def
__init__
(
self
,
params
=
None
,
logging_dir
=
None
):
def
__init__
(
self
,
params
=
None
,
logging_dir
=
None
,
name
=
None
):
super
().
__init__
(
params
=
params
,
logging_dir
=
logging_dir
)
super
().
__init__
(
params
=
params
,
logging_dir
=
logging_dir
,
name
=
name
)
def
build_model
(
self
,
*
arg
,
**
kwargs
):
def
build_model
(
self
,
*
arg
,
**
kwargs
):
inputs
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
2
,),
name
=
"random"
,
dtype
=
tf
.
float32
)
inputs
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
2
,),
name
=
"random"
,
dtype
=
tf
.
float32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment