ModelZoo / ResNet50_tensorflow / Commits / b4b675db

Commit b4b675db authored Dec 03, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 345588520

parent 55c284c8

Changes: 3
Showing 3 changed files with 174 additions and 107 deletions (+174 -107)
official/nlp/data/wmt_dataloader.py                  +85  -61
official/nlp/data/wmt_dataloader_test.py             +84  -45
official/nlp/modeling/models/seq2seq_transformer.py  +5   -1
official/nlp/data/wmt_dataloader.py  (View file @ b4b675db)
...
...
@@ -16,12 +16,6 @@
1. Batching scheme

   The examples encoded in the TFRecord files contain data in the format:
     {'inputs': [variable length array of integers],
      'targets': [variable length array of integers]}
   Where integers in the arrays refer to tokens in the English and German vocab
   file (named `vocab.ende.32768`).

   Prior to batching, elements in the dataset are grouped by length (max between
   'inputs' and 'targets' length). Each group is then batched such that:
     group_batch_size * length <= batch_size.
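As a quick aside (not part of the commit), the token-budget rule above is plain integer arithmetic; a minimal Python sketch with made-up numbers:

# Illustrative only: derive a per-group batch size from a token budget,
# mirroring the rule `group_batch_size * length <= batch_size` above.
token_budget = 1024            # hypothetical batch_size, measured in tokens
for bucket_length in (8, 16, 32, 64):
  group_batch_size = token_budget // bucket_length
  print(f'length <= {bucket_length:2d} -> {group_batch_size} examples '
        f'({group_batch_size * bucket_length} tokens)')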
...
...
@@ -37,32 +31,22 @@
   This batching scheme decreases the fraction of padding tokens per training
   batch, thus improving the training speed significantly.
"""
from typing import Optional
from typing import Dict, Optional

import dataclasses
import tensorflow as tf
import tensorflow_text as tftxt

from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory

# Buffer size for reading records from a TFRecord file. Each training file is
# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
_READ_RECORD_BUFFER = 8 * 1000 * 1000

# Example grouping constants. Defines length boundaries for each group.
# These values are the defaults used in Tensor2Tensor.
_MIN_BOUNDARY = 8
_BOUNDARY_SCALE = 1.1


def _filter_max_length(example, max_length=256):
  """Indicates whether the example's length is lower than the maximum length."""
  return tf.logical_and(tf.size(example[0]) <= max_length,
                        tf.size(example[1]) <= max_length)


def _get_example_length(example):
  """Returns the maximum length between the example inputs and targets."""
  length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
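For orientation (a sketch, not part of the diff), the two helpers above operate on an (inputs, targets) pair of 1-D integer tensors; this standalone snippet reproduces what they compute:

import tensorflow as tf

# Sketch: what _filter_max_length and _get_example_length compute for a toy pair.
inputs = tf.constant([4, 7, 9])           # 3 tokens
targets = tf.constant([5, 2, 8, 3, 1])    # 5 tokens
keep = tf.logical_and(tf.size(inputs) <= 4, tf.size(targets) <= 4)
length = tf.maximum(tf.shape(inputs)[0], tf.shape(targets)[0])
print(keep.numpy(), length.numpy())       # False 5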
...
...
@@ -181,7 +165,11 @@ class WMTDataConfig(cfg.DataConfig):
  """Data config for WMT translation."""
  max_seq_length: int = 64
  static_batch: bool = False
  vocab_file: str = ''
  sentencepiece_model_path: str = ''
  src_lang: str = ''
  tgt_lang: str = ''
  transform_and_batch: bool = True
  has_unique_id: bool = False


@data_loader_factory.register_data_loader_cls(WMTDataConfig)
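A hedged sketch of how the new fields might be filled in when building a config (the paths are placeholders; input_path, global_batch_size and is_training come from the base cfg.DataConfig rather than from this hunk):

from official.nlp.data import wmt_dataloader

# Sketch only: all file paths below are placeholders.
config = wmt_dataloader.WMTDataConfig(
    input_path='/tmp/wmt_train.tfrecord',       # hypothetical raw-text TFRecords
    src_lang='en',
    tgt_lang='de',                              # any feature key present in the records
    sentencepiece_model_path='/tmp/spm.model',  # hypothetical SentencePiece model
    max_seq_length=64,
    static_batch=False,
    global_batch_size=4096,                     # token budget per training batch
    is_training=True)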
...
...
@@ -193,24 +181,20 @@ class WMTDataLoader(data_loader.DataLoader):
    self._max_seq_length = params.max_seq_length
    self._static_batch = params.static_batch
    self._global_batch_size = params.global_batch_size
    if self._params.transform_and_batch:
      self._tokenizer = tftxt.SentencepieceTokenizer(
          model=tf.io.gfile.GFile(params.sentencepiece_model_path, 'rb').read(),
          add_eos=True)

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    if self._params.is_training:
      name_to_features = {
          'inputs': tf.io.VarLenFeature(tf.int64),
          'targets': tf.io.VarLenFeature(tf.int64)
          self._params.src_lang: tf.io.FixedLenFeature([], tf.string),
          self._params.tgt_lang: tf.io.FixedLenFeature([], tf.string),
      }
      if self._params.has_unique_id:
        name_to_features['unique_id'] = tf.io.FixedLenFeature([], tf.int64)
      example = tf.io.parse_single_example(record, name_to_features)
      example['inputs'] = tf.sparse.to_dense(example['inputs'])
      example['targets'] = tf.sparse.to_dense(example['targets'])
    else:
      name_to_features = {
          'inputs': tf.io.VarLenFeature(tf.int64),
          'unique_id': tf.io.FixedLenFeature([], tf.int64)
      }
      example = tf.io.parse_single_example(record, name_to_features)
      example['inputs'] = tf.sparse.to_dense(example['inputs'])
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
...
...
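To make the new raw-text path concrete, a self-contained sketch (not part of the commit; 'spm.model' is a placeholder) of the two steps the loader now performs per record: parse string features from a serialized tf.Example, then tokenize them with a SentencePiece model loaded through tf.io.gfile, with add_eos=True appending the EOS id to every sentence:

import tensorflow as tf
import tensorflow_text as tftxt

# Sketch: one serialized example with raw-text features, like the new test data.
record = tf.train.Example(features=tf.train.Features(feature={
    'en': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'abc ede fg'])),
    'reverse_en': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'gf ede abc'])),
})).SerializeToString()

parsed = tf.io.parse_single_example(record, {
    'en': tf.io.FixedLenFeature([], tf.string),
    'reverse_en': tf.io.FixedLenFeature([], tf.string),
})

# 'spm.model' is a placeholder for a trained SentencePiece model file.
tokenizer = tftxt.SentencepieceTokenizer(
    model=tf.io.gfile.GFile('spm.model', 'rb').read(), add_eos=True)
print(tokenizer.tokenize(parsed['en']))  # 1-D int32 token ids ending in EOS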
@@ -220,21 +204,64 @@ class WMTDataLoader(data_loader.DataLoader):
        example[name] = t
    return example

  def _bucketize_and_batch(
  def _tokenize(self, inputs) -> Dict[str, tf.Tensor]:
    tokenized_inputs = {}
    for k, v in inputs.items():
      if k == self._params.src_lang:
        tokenized_inputs['inputs'] = self._tokenizer.tokenize(v)
      elif k == self._params.tgt_lang:
        tokenized_inputs['targets'] = self._tokenizer.tokenize(v)
      else:
        tokenized_inputs[k] = v
    print(tokenized_inputs)
    return tokenized_inputs

  def _filter_max_length(self, inputs):
    # return tf.constant(True)
    return tf.logical_and(
        tf.shape(inputs['inputs'])[0] <= self._max_seq_length,
        tf.shape(inputs['targets'])[0] <= self._max_seq_length)

  def _maybe_truncate(self, inputs):
    truncated_inputs = {}
    for k, v in inputs.items():
      if k == 'inputs' or k == 'targets':
        truncated_inputs[k] = tf.pad(
            v[:self._max_seq_length - 1], [[0, 1]],
            constant_values=1) if tf.shape(v)[0] > self._max_seq_length else v
      else:
        truncated_inputs[k] = v
    return truncated_inputs

  def _tokenize_bucketize_and_batch(
      self, dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    # pylint: disable=g-long-lambda
    dataset = dataset.filter(lambda x: _filter_max_length(
        (x['inputs'], x['targets']), self._max_seq_length))
    # pylint: enable=g-long-lambda
    dataset = dataset.map(
        self._tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if self._params.is_training:
      dataset = dataset.filter(self._filter_max_length)
    else:
      dataset = dataset.map(
          self._maybe_truncate,
          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size
    if self._static_batch:
      padded_shapes = dict([(name, [self._max_seq_length])
                            for name, _ in dataset.element_spec.items()])
      padded_shapes = {}
      for name, _ in dataset.element_spec.items():
        if name == 'unique_id':
          padded_shapes[name] = []
        else:
          padded_shapes[name] = [self._max_seq_length
                                ] if self._static_batch else [None]
      batch_size = per_replica_batch_size
      if self._params.is_training:
        batch_size = int(batch_size // self._max_seq_length)
      dataset = dataset.padded_batch(
          int(per_replica_batch_size // self._max_seq_length),
          batch_size, padded_shapes, drop_remainder=True)
    else:
...
...
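The truncation rule in _maybe_truncate is easy to check in isolation; a sketch with a toy tensor (constant_values=1 matches the EOS id the test's SentencePiece model is trained with, --eos_id=1):

import tensorflow as tf

# Sketch: truncate an over-long sequence to max_seq_length, reserving the last
# position for EOS (id 1), exactly as _maybe_truncate does above.
max_seq_length = 5
v = tf.constant([11, 12, 13, 14, 15, 16, 17])  # 7 tokens > max_seq_length
truncated = tf.pad(v[:max_seq_length - 1], [[0, 1]], constant_values=1)
print(truncated.numpy())  # [11 12 13 14  1]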
@@ -245,27 +272,24 @@ class WMTDataLoader(data_loader.DataLoader):
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

  def _inference_padded_batch(
      self, dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    padded_shapes = {}
    for name, _ in dataset.element_spec.items():
      if name == 'unique_id':
        padded_shapes[name] = []
      else:
        padded_shapes[name] = [self._max_seq_length
                              ] if self._static_batch else [None]
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size
    return dataset.padded_batch(
        per_replica_batch_size, padded_shapes, drop_remainder=True)

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    decoder_fn = None
    # Only decode for TFRecords.
    if self._params.input_path:
      decoder_fn = self._decode

    def _identity(
        dataset, input_context: Optional[tf.distribute.InputContext] = None):
      del input_context
      return dataset

    transform_and_batch_fn = _identity
    if self._params.transform_and_batch:
      transform_and_batch_fn = self._tokenize_bucketize_and_batch

    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=self._decode,
        transform_and_batch_fn=self._bucketize_and_batch
        if self._params.is_training else self._inference_padded_batch)
        decoder_fn=decoder_fn,
        transform_and_batch_fn=transform_and_batch_fn)
    return reader.read(input_context)
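As a standalone illustration of the padded-batch step used on the inference path above (a sketch; unique_id handling and the per-replica split follow the same pattern as the code):

import tensorflow as tf

# Sketch: pad variable-length sequences to a fixed [max_seq_length] shape,
# the same pattern _inference_padded_batch applies via padded_shapes.
max_seq_length = 6
ds = tf.data.Dataset.range(2).map(
    lambda i: {'inputs': tf.ones([i + 2], dtype=tf.int32)})
ds = ds.padded_batch(
    2, padded_shapes={'inputs': [max_seq_length]}, drop_remainder=True)
print(next(iter(ds))['inputs'].numpy())  # shape (2, 6), zero-padded on the right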
official/nlp/data/wmt_dataloader_test.py  (View file @ b4b675db)
...
...
@@ -15,74 +15,113 @@
# ==============================================================================
"""Tests for official.nlp.data.wmt_dataloader."""
import os
import random

from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from sentencepiece import SentencePieceTrainer
from official.nlp.data import wmt_dataloader


def _create_fake_dataset(output_path):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)


def _generate_line_file(filepath, lines):
  with tf.io.gfile.GFile(filepath, 'w') as f:
    for l in lines:
      f.write('{}\n'.format(l))

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  for _ in range(20):
    features = {}
    seq_length = random.randint(20, 40)
    input_ids = np.random.randint(100, size=(seq_length))
    features['inputs'] = create_int_feature(input_ids)
    seq_length = random.randint(10, 80)
    targets = np.random.randint(100, size=(seq_length))
    features['targets'] = create_int_feature(targets)
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())


def _generate_record_file(filepath, src_lines, tgt_lines, unique_id=False):
  writer = tf.io.TFRecordWriter(filepath)
  for i, (src, tgt) in enumerate(zip(src_lines, tgt_lines)):
    features = {
        'en': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[src.encode()])),
        'reverse_en': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[tgt.encode()])),
    }
    if unique_id:
      features['unique_id'] = tf.train.Feature(
          int64_list=tf.train.Int64List(value=[i]))
    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(example.SerializeToString())
  writer.close()
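An aside, not part of the diff: records written by _generate_record_file above can be read back with the same feature spec the loader's _decode uses; a sketch with a placeholder path:

import tensorflow as tf

# Sketch: read back one raw-text record; '/tmp/train.record' is a placeholder.
spec = {
    'en': tf.io.FixedLenFeature([], tf.string),
    'reverse_en': tf.io.FixedLenFeature([], tf.string),
}
for raw in tf.data.TFRecordDataset('/tmp/train.record').take(1):
  print(tf.io.parse_single_example(raw, spec))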
class WMTDataLoaderTest(tf.test.TestCase):

def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
  argstr = ' '.join([
      f'--input={input_path}', f'--vocab_size={vocab_size}',
      '--character_coverage=0.995', f'--model_prefix={model_path}',
      '--model_type=bpe', '--bos_id=-1', '--pad_id=0', f'--eos_id={eos_id}',
      '--unk_id=2'
  ])
  SentencePieceTrainer.Train(argstr)

  def test_load_dataset(self):
    batch_tokens_size = 100
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(train_data_path)
    data_config = wmt_dataloader.WMTDataConfig(
        input_path=train_data_path,
        max_seq_length=35,
        global_batch_size=batch_tokens_size,
        is_training=True,
        static_batch=False)
    dataset = wmt_dataloader.WMTDataLoader(data_config).load()
    examples = next(iter(dataset))
    inputs, targets = examples['inputs'], examples['targets']
    logging.info('dynamic inputs=%s targets=%s', inputs, targets)


class WMTDataLoaderTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(WMTDataLoaderTest, self).setUp()
    self._temp_dir = self.get_temp_dir()
    src_lines = ['abc ede fg', 'bbcd ef a g', 'de f a a g']
    tgt_lines = ['dd cc a ef g', 'bcd ef a g', 'gef cd ba']
    self._record_train_input_path = os.path.join(self._temp_dir, 'train.record')
    _generate_record_file(self._record_train_input_path, src_lines, tgt_lines)
    self._record_test_input_path = os.path.join(self._temp_dir, 'test.record')
    _generate_record_file(self._record_test_input_path, src_lines, tgt_lines,
                          unique_id=True)
    self._sentencepeice_input_path = os.path.join(self._temp_dir, 'inputs.txt')
    _generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
    sentencepeice_model_prefix = os.path.join(self._temp_dir, 'sp')
    _train_sentencepiece(self._sentencepeice_input_path, 20,
                         sentencepeice_model_prefix)
    self._sentencepeice_model_path = '{}.model'.format(
        sentencepeice_model_prefix)

  @parameterized.named_parameters(
      ('train_static', True, True, 100, (2, 35)),
      ('train_non_static', True, False, 100, (12, 7)),
      ('non_train_static', False, True, 3, (3, 35)),
      ('non_train_non_static', False, False, 50, (2, 7)),)
  def test_load_dataset(self, is_training, static_batch, batch_size,
                        expected_shape):
    data_config = wmt_dataloader.WMTDataConfig(
        input_path=train_data_path,
        input_path=self._record_train_input_path
        if is_training else self._record_test_input_path,
        max_seq_length=35,
        global_batch_size=batch_tokens_size,
        is_training=True,
        static_batch=True)
        global_batch_size=batch_size,
        is_training=is_training,
        static_batch=static_batch,
        src_lang='en',
        tgt_lang='reverse_en',
        sentencepiece_model_path=self._sentencepeice_model_path)
    dataset = wmt_dataloader.WMTDataLoader(data_config).load()
    examples = next(iter(dataset))
    inputs, targets = examples['inputs'], examples['targets']
    logging.info('static inputs=%s targets=%s', inputs, targets)
    self.assertEqual(inputs.shape, (2, 35))
    self.assertEqual(targets.shape, (2, 35))
    self.assertEqual(inputs.shape, expected_shape)
    self.assertEqual(targets.shape, expected_shape)

  def test_load_dataset_raise_invalid_window(self):
    batch_tokens_size = 10  # this is too small to form buckets.
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(train_data_path)
    data_config = wmt_dataloader.WMTDataConfig(
        input_path=train_data_path,
        input_path=self._record_train_input_path,
        max_seq_length=100,
        global_batch_size=batch_tokens_size,
        is_training=True)
        is_training=True,
        static_batch=False,
        src_lang='en',
        tgt_lang='reverse_en',
        sentencepiece_model_path=self._sentencepeice_model_path)
    with self.assertRaisesRegex(
        ValueError, 'The token budget, global batch size, is too small.*'):
      _ = wmt_dataloader.WMTDataLoader(data_config).load()
...
...
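The SentencePiece training step in the test can be reproduced outside the test harness in a few lines; a sketch with toy sentences and placeholder paths, using the same flags as _train_sentencepiece:

from sentencepiece import SentencePieceTrainer

# Sketch: train a tiny BPE model the way the test does.
# '/tmp/sp_input.txt' and '/tmp/sp' are placeholder paths.
with open('/tmp/sp_input.txt', 'w') as f:
  f.write('\n'.join(['abc ede fg', 'bbcd ef a g', 'dd cc a ef g']))
SentencePieceTrainer.Train(
    '--input=/tmp/sp_input.txt --vocab_size=20 --character_coverage=0.995 '
    '--model_prefix=/tmp/sp --model_type=bpe --bos_id=-1 --pad_id=0 '
    '--eos_id=1 --unk_id=2')
# Produces /tmp/sp.model, usable as WMTDataConfig.sentencepiece_model_path.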
official/nlp/modeling/models/seq2seq_transformer.py  (View file @ b4b675db)
...
...
@@ -53,6 +53,7 @@ class Seq2SeqTransformer(tf.keras.Model):
               encoder_layer=None,
               decoder_layer=None,
               dtype=tf.float32,
               eos_id=EOS_ID,
               **kwargs):
    """Initialize layers to build Transformer model.
...
...
@@ -69,6 +70,7 @@ class Seq2SeqTransformer(tf.keras.Model):
      encoder_layer: An initialized encoder layer.
      decoder_layer: An initialized decoder layer.
      dtype: float dtype.
      eos_id: Id of end of sentence token.
      **kwargs: other keyword arguments.
    """
    super(Seq2SeqTransformer, self).__init__(**kwargs)
...
...
@@ -81,6 +83,7 @@ class Seq2SeqTransformer(tf.keras.Model):
    self._beam_size = beam_size
    self._alpha = alpha
    self._dtype = dtype
    self._eos_id = eos_id
    self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
        vocab_size=self._vocab_size,
        embedding_width=self._embedding_width,
...
...
@@ -102,6 +105,7 @@ class Seq2SeqTransformer(tf.keras.Model):
        "padded_decode": self._padded_decode,
        "decode_max_length": self._decode_max_length,
        "dtype": self._dtype,
        "eos_id": self._eos_id,
        "extra_decode_length": self._extra_decode_length,
        "beam_size": self._beam_size,
        "alpha": self._alpha,
...
...
@@ -226,7 +230,7 @@ class Seq2SeqTransformer(tf.keras.Model):
          beam_size=self._beam_size,
          alpha=self._alpha,
          max_decode_length=max_decode_length,
          eos_id=EOS_ID,
          eos_id=self._eos_id,
          padded_decode=self._padded_decode,
          dtype=self._dtype)
...
...
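The eos_id plumbing above follows the usual Keras config round trip; a minimal sketch with a toy layer (not Seq2SeqTransformer itself) showing how such a constructor argument survives get_config / from_config:

import tensorflow as tf

# Sketch: a toy layer carrying eos_id through get_config, mirroring the change.
class ToyDecoder(tf.keras.layers.Layer):

  def __init__(self, eos_id=1, **kwargs):
    super().__init__(**kwargs)
    self._eos_id = eos_id

  def get_config(self):
    config = super().get_config()
    config.update({'eos_id': self._eos_id})
    return config

restored = ToyDecoder.from_config(ToyDecoder(eos_id=7).get_config())
print(restored._eos_id)  # 7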