ModelZoo / ResNet50_tensorflow · Commits

Commit 5a55f69d
Authored May 06, 2020 by Chen Chen; committed by A. Unique TensorFlower, May 06, 2020

Internal change

PiperOrigin-RevId: 310104070
Parent: 44c3e33f
Showing 1 changed file with 17 additions and 13 deletions.
official/nlp/bert/input_pipeline.py (+17, -13) @ 5a55f69d
@@ -41,7 +41,9 @@ def single_file_dataset(input_file, name_to_features):
   # For training, we want a lot of parallel reading and shuffling.
   # For eval, we want no shuffling and parallel reading doesn't matter.
   d = tf.data.TFRecordDataset(input_file)
-  d = d.map(lambda record: decode_record(record, name_to_features))
+  d = d.map(
+      lambda record: decode_record(record, name_to_features),
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
   # When `input_file` is a path to a single file or a list
   # containing a single path, disable auto sharding so that
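Note: the common thread of this commit is making the tf.data pipelines self-tuning: the record-decoding map calls gain num_parallel_calls=tf.data.experimental.AUTOTUNE, and the fixed prefetch(1024) buffers become prefetch(tf.data.experimental.AUTOTUNE). A minimal sketch of the pattern from this first hunk; the file path and feature spec below are placeholders, not code from the commit:

    import tensorflow as tf

    # Placeholder feature spec; the real one is built by the BERT pipeline.
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([128], tf.int64),
    }

    def decode_record(record, name_to_features):
      """Parses one serialized tf.train.Example into a dict of tensors."""
      return tf.io.parse_single_example(record, name_to_features)

    d = tf.data.TFRecordDataset('train.tfrecord')  # placeholder path
    # Before: sequential decode.
    #   d = d.map(lambda record: decode_record(record, name_to_features))
    # After: tf.data picks the decode parallelism at runtime.
    d = d.map(
        lambda record: decode_record(record, name_to_features),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)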
@@ -107,9 +109,13 @@ def create_pretrain_dataset(input_patterns,
   # parallel. You may want to increase this number if you have a large number of
   # CPU cores.
   dataset = dataset.interleave(
       tf.data.TFRecordDataset, cycle_length=8,
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+  if is_training:
+    dataset = dataset.shuffle(100)
+
   decode_fn = lambda record: decode_record(record, name_to_features)
   dataset = dataset.map(
       decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
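Note: together with the next hunk, this moves the training-time shuffle(100) from after the select map to before the decode map. A plausible reading (the commit message does not state the motivation) is that the shuffle buffer then holds compact serialized records rather than decoded feature dicts. A sketch with a placeholder file pattern:

    import tensorflow as tf

    is_training = True
    dataset = tf.data.Dataset.list_files('pretrain-*.tfrecord')  # placeholder
    dataset = dataset.interleave(
        tf.data.TFRecordDataset, cycle_length=8,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if is_training:
      # Shuffling before the decode map keeps raw serialized strings,
      # not parsed tensors, in the 100-element shuffle buffer.
      dataset = dataset.shuffle(100)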
@@ -136,12 +142,8 @@ def create_pretrain_dataset(input_patterns,
   dataset = dataset.map(
       _select_data_from_record,
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
-
-  if is_training:
-    dataset = dataset.shuffle(100)
-
   dataset = dataset.batch(batch_size, drop_remainder=is_training)
-  dataset = dataset.prefetch(1024)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
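Note: prefetch(1024) keeps up to 1024 ready elements buffered, and since it sits after batch(), those elements are whole batches, which can pin a lot of host memory at large batch sizes. AUTOTUNE lets the tf.data runtime size the buffer dynamically. A toy illustration (not code from the commit):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(10000).batch(32)
    # Before: dataset = dataset.prefetch(1024)  # up to 1024 ready batches
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)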
@@ -174,14 +176,15 @@ def create_classifier_dataset(file_path,
     y = record['label_ids']
     return (x, y)
 
-  dataset = dataset.map(_select_data_from_record)
-
   if is_training:
     dataset = dataset.shuffle(100)
     dataset = dataset.repeat()
+  dataset = dataset.map(
+      _select_data_from_record,
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
   dataset = dataset.batch(batch_size, drop_remainder=is_training)
-  dataset = dataset.prefetch(1024)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
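Note: in the classifier pipeline the _select_data_from_record map is both parallelized and moved after shuffle/repeat. A self-contained sketch of the resulting ordering, with a toy dataset standing in for the decoded TFRecords:

    import tensorflow as tf

    batch_size = 32
    is_training = True

    # Toy stand-in for the decoded record dicts.
    dataset = tf.data.Dataset.range(1000).map(
        lambda i: {'input_ids': i, 'label_ids': i % 2})

    def _select_data_from_record(record):
      # Split each record dict into (features, label), as in the real pipeline.
      return {'input_ids': record['input_ids']}, record['label_ids']

    if is_training:
      dataset = dataset.shuffle(100)
      dataset = dataset.repeat()
    dataset = dataset.map(
        _select_data_from_record,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)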
@@ -224,12 +227,13 @@ def create_squad_dataset(file_path,
       x[name] = tensor
     return (x, y)
 
-  dataset = dataset.map(_select_data_from_record)
-
   if is_training:
     dataset = dataset.shuffle(100)
     dataset = dataset.repeat()
+  dataset = dataset.map(
+      _select_data_from_record,
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
   dataset = dataset.batch(batch_size, drop_remainder=True)
-  dataset = dataset.prefetch(1024)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
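Note: the SQuAD hunk applies the same change as the classifier one; the only difference is that it batches with drop_remainder=True in both training and eval. A short usage sketch of consuming a pipeline built this way:

    import tensorflow as tf

    dataset = tf.data.Dataset.range(100).batch(8, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    for batch in dataset.take(2):
      print(batch.shape)  # (8,)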