ModelZoo / ResNet50_tensorflow / Commits

Commit 5b6be76b
Authored Oct 02, 2020 by Hongkun Yu; committed by A. Unique TensorFlower, Oct 02, 2020

Internal change

PiperOrigin-RevId: 335106919
parent e1c78a72
Showing 2 changed files with 39 additions and 11 deletions:

  official/nlp/data/wmt_dataloader.py        +35  -10
  official/nlp/data/wmt_dataloader_test.py    +4   -1
official/nlp/data/wmt_dataloader.py (view file @ 5b6be76b)

@@ -182,6 +182,7 @@ class WMTDataConfig(cfg.DataConfig):
   """Data config for WMT translation."""
   max_seq_length: int = 64
   static_batch: bool = False
+  vocab_file: str = ''


 @data_loader_factory.register_data_loader_cls(WMTDataConfig)
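
Note: `vocab_file` is a new field on `WMTDataConfig`. A minimal construction sketch, with hypothetical paths and sizes; only `vocab_file` is new in this commit, the other fields already exist on the config:

    from official.nlp.data import wmt_dataloader

    # Hypothetical values, for illustration only.
    config = wmt_dataloader.WMTDataConfig(
        input_path='/tmp/wmt_train.tfrecord',  # hypothetical path
        vocab_file='/tmp/wmt.vocab',           # field added in this commit
        max_seq_length=64,
        static_batch=False,
        global_batch_size=4096,
        is_training=True)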
@@ -196,6 +197,7 @@ class WMTDataLoader(data_loader.DataLoader):

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
+    if self._params.is_training:
       name_to_features = {
           'inputs': tf.io.VarLenFeature(tf.int64),
           'targets': tf.io.VarLenFeature(tf.int64)
@@ -203,6 +205,13 @@ class WMTDataLoader(data_loader.DataLoader):
       example = tf.io.parse_single_example(record, name_to_features)
       example['inputs'] = tf.sparse.to_dense(example['inputs'])
       example['targets'] = tf.sparse.to_dense(example['targets'])
+    else:
+      name_to_features = {
+          'inputs': tf.io.VarLenFeature(tf.int64),
+          'unique_id': tf.io.FixedLenFeature([], tf.int64)
+      }
+      example = tf.io.parse_single_example(record, name_to_features)
+      example['inputs'] = tf.sparse.to_dense(example['inputs'])
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
     # So cast all int64 to int32.
     for name in example:
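
This change splits `_decode` by mode: training records carry variable-length `inputs` and `targets`, while inference records carry `inputs` plus a scalar `unique_id`. A sketch of serializing records that match each spec (the token IDs are made up):

    import tensorflow as tf

    def _int64_feature(values):
      return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    # Matches the is_training branch: 'inputs' and 'targets', both VarLen.
    train_record = tf.train.Example(features=tf.train.Features(feature={
        'inputs': _int64_feature([17, 52, 9]),   # made-up token IDs
        'targets': _int64_feature([23, 4]),
    })).SerializeToString()

    # Matches the inference branch: 'inputs' plus a scalar 'unique_id'.
    infer_record = tf.train.Example(features=tf.train.Features(feature={
        'inputs': _int64_feature([17, 52, 9]),
        'unique_id': _int64_feature([0]),
    })).SerializeToString()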
@@ -224,8 +233,7 @@ class WMTDataLoader(data_loader.DataLoader):
         self._global_batch_size) if input_context else self._global_batch_size
     if self._static_batch:
       padded_shapes = dict([(name, [self._max_seq_length])
-                            for name, _ in dataset.element_spec.items()
-                           ])
+                            for name, _ in dataset.element_spec.items()])
       dataset = dataset.padded_batch(
           int(per_replica_batch_size // self._max_seq_length),
           padded_shapes,
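
On this path `global_batch_size` acts as a token budget: with static batching every sequence is padded to `max_seq_length`, so a batch holds `per_replica_batch_size // max_seq_length` sequences. Illustrative arithmetic (the numbers are made up):

    per_replica_batch_size = 1024   # token budget per replica
    max_seq_length = 64
    sequences_per_batch = per_replica_batch_size // max_seq_length  # -> 16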
@@ -238,10 +246,27 @@ class WMTDataLoader(data_loader.DataLoader):
     dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
     return dataset

+  def _inference_padded_batch(
+      self,
+      dataset,
+      input_context: Optional[tf.distribute.InputContext] = None):
+    padded_shapes = {}
+    for name, _ in dataset.element_spec.items():
+      if name == 'unique_id':
+        padded_shapes[name] = []
+      else:
+        padded_shapes[name] = [self._max_seq_length
+                              ] if self._static_batch else [None]
+    per_replica_batch_size = input_context.get_per_replica_batch_size(
+        self._global_batch_size) if input_context else self._global_batch_size
+    return dataset.padded_batch(
+        per_replica_batch_size, padded_shapes, drop_remainder=True)
+
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
     reader = input_reader.InputReader(
         params=self._params,
         decoder_fn=self._decode,
-        transform_and_batch_fn=self._bucketize_and_batch)
+        transform_and_batch_fn=self._bucketize_and_batch
+        if self._params.is_training else self._inference_padded_batch)
     return reader.read(input_context)
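
After this change `load()` dispatches on mode: `_bucketize_and_batch` for training, the new `_inference_padded_batch` otherwise. A hedged usage sketch of the inference path, with hypothetical path and sizes; note that here the batch size counts sequences, since `_inference_padded_batch` passes it straight to `padded_batch`:

    # Hypothetical inference-side usage.
    data_config = wmt_dataloader.WMTDataConfig(
        input_path='/tmp/wmt_eval.tfrecord',  # hypothetical path
        max_seq_length=64,
        global_batch_size=8,   # sequences per batch on this path
        static_batch=True,
        is_training=False)
    dataset = wmt_dataloader.WMTDataLoader(data_config).load()
    batch = next(iter(dataset))  # dict with 'inputs' and 'unique_id'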
official/nlp/data/wmt_dataloader_test.py (view file @ 5b6be76b)

@@ -55,6 +55,7 @@ class WMTDataLoaderTest(tf.test.TestCase):
         input_path=train_data_path,
         max_seq_length=35,
         global_batch_size=batch_tokens_size,
+        is_training=True,
         static_batch=False)
     dataset = wmt_dataloader.WMTDataLoader(data_config).load()
     examples = next(iter(dataset))
@@ -64,6 +65,7 @@ class WMTDataLoaderTest(tf.test.TestCase):
         input_path=train_data_path,
         max_seq_length=35,
         global_batch_size=batch_tokens_size,
+        is_training=True,
         static_batch=True)
     dataset = wmt_dataloader.WMTDataLoader(data_config).load()
     examples = next(iter(dataset))
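
Both training tests now pass `is_training=True` explicitly. This matters after the `_decode` change above: with the flag unset, the loader would parse records against the inference spec and require a `unique_id` feature (a `FixedLenFeature` with no default), so the training TFRecords these tests write would fail to parse.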
@@ -79,7 +81,8 @@ class WMTDataLoaderTest(tf.test.TestCase):
     data_config = wmt_dataloader.WMTDataConfig(
         input_path=train_data_path,
         max_seq_length=100,
-        global_batch_size=batch_tokens_size)
+        global_batch_size=batch_tokens_size,
+        is_training=True)
     with self.assertRaisesRegex(
         ValueError, 'The token budget, global batch size, is too small.*'):
       _ = wmt_dataloader.WMTDataLoader(data_config).load()
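
The error-path test likewise gains `is_training=True`, which routes `load()` through `_bucketize_and_batch`, the training path where (on my reading) the token-budget check lives. The asserted failure is arithmetic: with `max_seq_length=100` and a token budget below that, not even one sequence fits in a batch. A sketch of the condition (the budget value is hypothetical):

    batch_tokens_size = 64   # hypothetical budget below max_seq_length
    max_seq_length = 100
    assert batch_tokens_size // max_seq_length == 0  # no room for one sequence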