Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3e8f1178
Commit
3e8f1178
authored
Nov 12, 2020
by
Le Hou
Committed by
A. Unique TensorFlower
Nov 12, 2020
Browse files
code clean up.
PiperOrigin-RevId: 342179055
parent
b7cbd12b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
8 deletions
+5
-8
official/core/train_lib.py
official/core/train_lib.py
+1
-2
official/core/train_utils.py
official/core/train_utils.py
+2
-4
official/nlp/train.py
official/nlp/train.py
+1
-1
official/vision/beta/train.py
official/vision/beta/train.py
+1
-1
No files found.
official/core/train_lib.py
View file @
3e8f1178
...
@@ -25,9 +25,9 @@ from absl import logging
...
@@ -25,9 +25,9 @@ from absl import logging
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
train_utils
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.core
import
config_definitions
from
official.core
import
train_utils
class
BestCheckpointExporter
:
class
BestCheckpointExporter
:
...
@@ -172,7 +172,6 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
...
@@ -172,7 +172,6 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
trainer
=
train_utils
.
create_trainer
(
trainer
=
train_utils
.
create_trainer
(
params
,
params
,
task
,
task
,
model_dir
=
model_dir
,
train
=
'train'
in
mode
,
train
=
'train'
in
mode
,
evaluate
=
(
'eval'
in
mode
)
or
run_post_eval
,
evaluate
=
(
'eval'
in
mode
)
or
run_post_eval
,
checkpoint_exporter
=
maybe_create_best_ckpt_exporter
(
params
,
model_dir
))
checkpoint_exporter
=
maybe_create_best_ckpt_exporter
(
params
,
model_dir
))
...
...
official/core/train_utils.py
View file @
3e8f1178
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
import
json
import
json
import
os
import
os
import
pprint
import
pprint
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
List
from
absl
import
logging
from
absl
import
logging
import
dataclasses
import
dataclasses
...
@@ -36,10 +36,8 @@ def create_trainer(params: config_definitions.ExperimentConfig,
...
@@ -36,10 +36,8 @@ def create_trainer(params: config_definitions.ExperimentConfig,
task
:
base_task
.
Task
,
task
:
base_task
.
Task
,
train
:
bool
,
train
:
bool
,
evaluate
:
bool
,
evaluate
:
bool
,
checkpoint_exporter
:
Any
=
None
,
checkpoint_exporter
:
Any
=
None
)
->
base_trainer
.
Trainer
:
model_dir
:
Optional
[
str
]
=
None
)
->
base_trainer
.
Trainer
:
"""Create trainer."""
"""Create trainer."""
del
model_dir
logging
.
info
(
'Running default trainer.'
)
logging
.
info
(
'Running default trainer.'
)
model
=
task
.
build_model
()
model
=
task
.
build_model
()
optimizer
=
base_trainer
.
create_optimizer
(
params
.
trainer
,
params
.
runtime
)
optimizer
=
base_trainer
.
create_optimizer
(
params
.
trainer
,
params
.
runtime
)
...
...
official/nlp/train.py
View file @
3e8f1178
...
@@ -19,7 +19,6 @@ from absl import app
...
@@ -19,7 +19,6 @@ from absl import app
from
absl
import
flags
from
absl
import
flags
import
gin
import
gin
from
official.core
import
train_utils
from
official.common
import
distribute_utils
from
official.common
import
distribute_utils
# pylint: disable=unused-import
# pylint: disable=unused-import
from
official.common
import
registry_imports
from
official.common
import
registry_imports
...
@@ -27,6 +26,7 @@ from official.common import registry_imports
...
@@ -27,6 +26,7 @@ from official.common import registry_imports
from
official.common
import
flags
as
tfm_flags
from
official.common
import
flags
as
tfm_flags
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.modeling
import
performance
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
...
official/vision/beta/train.py
View file @
3e8f1178
...
@@ -19,7 +19,6 @@ from absl import app
...
@@ -19,7 +19,6 @@ from absl import app
from
absl
import
flags
from
absl
import
flags
import
gin
import
gin
from
official.core
import
train_utils
# pylint: disable=unused-import
# pylint: disable=unused-import
from
official.common
import
registry_imports
from
official.common
import
registry_imports
# pylint: enable=unused-import
# pylint: enable=unused-import
...
@@ -27,6 +26,7 @@ from official.common import distribute_utils
...
@@ -27,6 +26,7 @@ from official.common import distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.common
import
flags
as
tfm_flags
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.modeling
import
performance
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment