Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ca574c36
Commit
ca574c36
authored
Aug 19, 2019
by
A. Unique TensorFlower
Browse files
Dataset fixes.
PiperOrigin-RevId: 264174941
parent
3fd8bcfb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
11 deletions
+12
-11
official/transformer/v2/transformer_main.py
official/transformer/v2/transformer_main.py
+12
-11
No files found.
official/transformer/v2/transformer_main.py
View file @
ca574c36
...
...
@@ -171,21 +171,22 @@ class TransformerTask(object):
model
.
summary
()
train_ds
=
data_pipeline
.
train_input_fn
(
params
)
if
self
.
use_tpu
:
if
params
[
"is_tpu_pod"
]:
# Different from experimental_distribute_dataset,
# experimental_distribute_datasets_from_function requires
# per-replica/local batch size.
params
[
"batch_size"
]
/=
self
.
distribution_strategy
.
num_replicas_in_sync
train_ds
=
(
self
.
distribution_strategy
.
experimental_distribute_datasets_from_function
(
lambda
:
data_pipeline
.
train_input_fn
(
params
)))
else
:
train_ds
=
(
self
.
distribution_strategy
.
experimental_distribute_dataset
(
train_ds
)
)
lambda
ctx
:
data_pipeline
.
train_input_fn
(
params
)))
else
:
train_ds
=
data_pipeline
.
train_input_fn
(
params
)
map_data_fn
=
data_pipeline
.
map_data_for_transformer_fn
train_ds
=
train_ds
.
map
(
map_data_fn
,
num_parallel_calls
=
params
[
"num_parallel_calls"
])
if
params
[
"use_ctl"
]:
train_ds_iterator
=
iter
(
train_ds
)
callbacks
=
self
.
_create_callbacks
(
flags_obj
.
model_dir
,
0
,
params
)
...
...
@@ -251,7 +252,7 @@ class TransformerTask(object):
flags_obj
.
steps_between_evals
,
dtype
=
tf
.
int32
)
# Runs training steps.
train_steps
(
iter
(
train_ds
)
,
train_steps_per_eval
)
train_steps
(
train_ds
_iterator
,
train_steps_per_eval
)
train_loss
=
train_loss_metric
.
result
().
numpy
().
astype
(
float
)
logging
.
info
(
"Train Step: %d/%d / loss = %s"
,
i
*
flags_obj
.
steps_between_evals
,
flags_obj
.
train_steps
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment