ModelZoo / ResNet50_tensorflow · Commit 3fb1e20f

Authored May 07, 2020 by Abdullah Rashwan; committed by A. Unique TensorFlower on May 07, 2020.

Internal change

PiperOrigin-RevId: 310487163
Parent: e9e6d17c
Showing 2 changed files with 8 additions and 8 deletions:

official/nlp/transformer/data_pipeline.py    +6 -6
official/nlp/transformer/transformer_main.py +2 -2
official/nlp/transformer/data_pipeline.py
@@ -193,7 +193,7 @@ def _batch_examples(dataset, batch_size, max_length):
 def _read_and_batch_from_files(
-    file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
+    file_pattern, batch_size, max_length, max_io_parallelism, shuffle, repeat,
     static_batch=False, num_replicas=1, ctx=None):
   """Create dataset where each item is a dict of "inputs" and "targets".
@@ -201,7 +201,7 @@ def _read_and_batch_from_files(
     file_pattern: String used to match the input TFRecord files.
     batch_size: Maximum number of tokens per global batch of examples.
     max_length: Maximum number of tokens per example
-    num_parallel_calls: Number of cpu cores for parallel input processing.
+    max_io_parallelism: Max number of cpu cores for parallel input processing.
     shuffle: If true, randomizes order of elements.
     repeat: Number of times to repeat the dataset. If None, the dataset is
       repeated forever.
@@ -237,13 +237,13 @@ def _read_and_batch_from_files(
   options.experimental_deterministic = False
   dataset = dataset.interleave(
       _load_records,
-      cycle_length=num_parallel_calls,
+      cycle_length=max_io_parallelism,
       num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)
 
   # Parse each tf.Example into a dictionary
   # TODO: Look into prefetch_input_elements for performance optimization.
   dataset = dataset.map(_parse_example,
-                        num_parallel_calls=num_parallel_calls)
+                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
   # Remove examples where the input or target length exceeds the maximum length,
   dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
@@ -289,7 +289,7 @@ def train_input_fn(params, ctx=None):
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
-      params["num_parallel_calls"], shuffle=True,
+      params["max_io_parallelism"], shuffle=True,
       repeat=params["repeat_dataset"], static_batch=params["static_batch"],
       num_replicas=params["num_gpus"], ctx=ctx)
@@ -301,7 +301,7 @@ def eval_input_fn(params, ctx=None):
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
-      params["num_parallel_calls"], shuffle=False, repeat=1,
+      params["max_io_parallelism"], shuffle=False, repeat=1,
       static_batch=params["static_batch"], num_replicas=params["num_gpus"],
       ctx=ctx)
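Net effect of the data_pipeline.py changes: the caller-supplied value now only caps how many input files are opened concurrently (interleave's cycle_length), while the degree of parallelism for both the interleaved reads and the per-record map is delegated to tf.data's autotuner. A minimal self-contained sketch of the same pattern, assuming TFRecord inputs; make_dataset and the feature spec are illustrative stand-ins, not code from this commit:

import tensorflow as tf

def make_dataset(file_pattern, max_io_parallelism):
  """Mirrors the pipeline's I/O pattern; parsing logic is a placeholder."""
  dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)

  # Reads may complete out of order; determinism is traded for throughput.
  options = tf.data.Options()
  options.experimental_deterministic = False

  # max_io_parallelism caps only the file-level fan-out (cycle_length);
  # the number of parallel readers is left to AUTOTUNE.
  dataset = dataset.interleave(
      tf.data.TFRecordDataset,
      cycle_length=max_io_parallelism,
      num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)

  def _parse(serialized):
    # Placeholder feature spec, not the repo's _parse_example.
    return tf.io.parse_single_example(
        serialized, {"inputs": tf.io.VarLenFeature(tf.int64)})

  # Per-record CPU work is autotuned rather than pinned to a core count.
  return dataset.map(_parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)

Bounding only the file-level fan-out keeps the number of open readers predictable, while AUTOTUNE adapts the worker count to the host at runtime.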
official/nlp/transformer/transformer_main.py
@@ -148,7 +148,7 @@ class TransformerTask(object):
     params["decode_batch_size"] = flags_obj.decode_batch_size
     params["decode_max_length"] = flags_obj.decode_max_length
     params["padded_decode"] = flags_obj.padded_decode
-    params["num_parallel_calls"] = (
+    params["max_io_parallelism"] = (
         flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
     params["use_synthetic_data"] = flags_obj.use_synthetic_data
@@ -239,7 +239,7 @@ class TransformerTask(object):
     train_ds = data_pipeline.train_input_fn(params)
     map_data_fn = data_pipeline.map_data_for_transformer_fn
     train_ds = train_ds.map(
-        map_data_fn, num_parallel_calls=params["num_parallel_calls"])
+        map_data_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     if params["use_ctl"]:
       train_ds_iterator = iter(train_ds)
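In transformer_main.py the existing --num_parallel_calls flag is kept but demoted to an optional override: a falsy value falls through to AUTOTUNE, and the training-loop map call no longer consults the param at all. A small sketch of that fallback idiom, assuming a hypothetical _Flags stand-in for the repo's absl flags object:

import tensorflow as tf

class _Flags:
  # Stand-in for the absl flags object; None models an unset flag.
  num_parallel_calls = None

flags_obj = _Flags()
params = {}

# Falsy flag values (None, 0) fall through to the autotune sentinel.
params["max_io_parallelism"] = (
    flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

# tf.data.experimental.AUTOTUNE is the integer sentinel -1 in TF 2.x.
print(params["max_io_parallelism"])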