Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
26ea4d1a
Commit
26ea4d1a
authored
Apr 15, 2020
by
Allen Wang
Committed by
A. Unique TensorFlower
Apr 15, 2020
Browse files
Internal change
PiperOrigin-RevId: 306699912
parent
1fffb174
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
5 deletions
+5
-5
official/vision/image_classification/classifier_trainer.py
official/vision/image_classification/classifier_trainer.py
+2
-2
official/vision/image_classification/dataset_factory.py
official/vision/image_classification/dataset_factory.py
+3
-3
No files found.
official/vision/image_classification/classifier_trainer.py
View file @
26ea4d1a
...
@@ -101,8 +101,8 @@ def get_image_size_from_model(
...
@@ -101,8 +101,8 @@ def get_image_size_from_model(
def
_get_dataset_builders
(
params
:
base_configs
.
ExperimentConfig
,
def
_get_dataset_builders
(
params
:
base_configs
.
ExperimentConfig
,
strategy
:
tf
.
distribute
.
Strategy
,
strategy
:
tf
.
distribute
.
Strategy
,
one_hot
:
bool
one_hot
:
bool
)
->
Tuple
[
Any
,
Any
,
Any
]:
)
->
Tuple
[
Any
,
Any
]:
"""Create and return train
,
validation
, and test
dataset builders."""
"""Create and return train
and
validation dataset builders."""
if
one_hot
:
if
one_hot
:
logging
.
warning
(
'label_smoothing > 0, so datasets will be one hot encoded.'
)
logging
.
warning
(
'label_smoothing > 0, so datasets will be one hot encoded.'
)
else
:
else
:
...
...
official/vision/image_classification/dataset_factory.py
View file @
26ea4d1a
...
@@ -116,7 +116,7 @@ class DatasetConfig(base_config.Config):
...
@@ -116,7 +116,7 @@ class DatasetConfig(base_config.Config):
num_channels
:
Union
[
int
,
str
]
=
'infer'
num_channels
:
Union
[
int
,
str
]
=
'infer'
num_examples
:
Union
[
int
,
str
]
=
'infer'
num_examples
:
Union
[
int
,
str
]
=
'infer'
batch_size
:
int
=
128
batch_size
:
int
=
128
use_per_replica_batch_size
:
bool
=
Fals
e
use_per_replica_batch_size
:
bool
=
Tru
e
num_devices
:
int
=
1
num_devices
:
int
=
1
dtype
:
str
=
'float32'
dtype
:
str
=
'float32'
one_hot
:
bool
=
True
one_hot
:
bool
=
True
...
@@ -185,14 +185,14 @@ class DatasetBuilder:
...
@@ -185,14 +185,14 @@ class DatasetBuilder:
def
batch_size
(
self
)
->
int
:
def
batch_size
(
self
)
->
int
:
"""The batch size, multiplied by the number of replicas (if configured)."""
"""The batch size, multiplied by the number of replicas (if configured)."""
if
self
.
config
.
use_per_replica_batch_size
:
if
self
.
config
.
use_per_replica_batch_size
:
return
self
.
global_
batch_size
return
self
.
config
.
batch_size
*
self
.
config
.
num_devices
else
:
else
:
return
self
.
config
.
batch_size
return
self
.
config
.
batch_size
@
property
@
property
def
global_batch_size
(
self
):
def
global_batch_size
(
self
):
"""The global batch size across all replicas."""
"""The global batch size across all replicas."""
return
self
.
config
.
batch_size
*
self
.
config
.
num_devices
return
self
.
batch_size
@
property
@
property
def
num_steps
(
self
)
->
int
:
def
num_steps
(
self
)
->
int
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment