ModelZoo / ResNet50_tensorflow / Commits / b4b675db

Commit b4b675db authored Dec 03, 2020 by A. Unique TensorFlower

    Internal change

    PiperOrigin-RevId: 345588520

Parent: 55c284c8

Showing 3 changed files with 174 additions and 107 deletions (+174, -107):
  official/nlp/data/wmt_dataloader.py                    +85  -61
  official/nlp/data/wmt_dataloader_test.py               +84  -45
  official/nlp/modeling/models/seq2seq_transformer.py    +5   -1
official/nlp/data/wmt_dataloader.py (view file @ b4b675db)
@@ -16,12 +16,6 @@
 1. Batching scheme

-   The examples encoded in the TFRecord files contain data in the format:
-     {'inputs': [variable length array of integers],
-      'targets': [variable length array of integers]}
-   Where integers in the arrays refer to tokens in the English and German vocab
-   file (named `vocab.ende.32768`).
-
    Prior to batching, elements in the dataset are grouped by length (max between
    'inputs' and 'targets' length). Each group is then batched such that:
      group_batch_size * length <= batch_size.
@@ -37,32 +31,22 @@
   This batching scheme decreases the fraction of padding tokens per training
   batch, thus improving the training speed significantly.
 """
-from typing import Optional
+from typing import Dict, Optional

 import dataclasses
 import tensorflow as tf
+import tensorflow_text as tftxt

 from official.core import config_definitions as cfg
 from official.core import input_reader
 from official.nlp.data import data_loader
 from official.nlp.data import data_loader_factory

-# Buffer size for reading records from a TFRecord file. Each training file is
-# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
-_READ_RECORD_BUFFER = 8 * 1000 * 1000
-
 # Example grouping constants. Defines length boundaries for each group.
 # These values are the defaults used in Tensor2Tensor.
 _MIN_BOUNDARY = 8
 _BOUNDARY_SCALE = 1.1

-def _filter_max_length(example, max_length=256):
-  """Indicates whether the example's length is lower than the maximum length."""
-  return tf.logical_and(tf.size(example[0]) <= max_length,
-                        tf.size(example[1]) <= max_length)
-

 def _get_example_length(example):
   """Returns the maximum length between the example inputs and targets."""
   length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
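The module docstring above describes the loader's token-budget batching: examples are bucketed by length (the max of 'inputs' and 'targets') and each bucket gets a batch size chosen so that group_batch_size * length <= batch_size, which keeps the fraction of padding tokens low. The following is a minimal, illustrative sketch of that idea, not code from this commit; it uses tf.data.experimental.bucket_by_sequence_length, and the helper name and default values are assumptions for illustration only.

import tensorflow as tf

def illustrative_token_budget_batching(dataset, batch_tokens=1024,
                                       max_length=64, min_boundary=8,
                                       boundary_scale=1.1):
  """Buckets {'inputs', 'targets'} examples so each batch fits a token budget."""
  # Length boundaries grow roughly geometrically, mirroring the
  # _MIN_BOUNDARY / _BOUNDARY_SCALE constants used by the loader.
  boundaries = []
  length = min_boundary
  while length < max_length:
    boundaries.append(length)
    length = max(length + 1, int(length * boundary_scale))
  # One batch size per bucket: group_batch_size * bucket_length <= batch_tokens.
  batch_sizes = [max(1, batch_tokens // l) for l in boundaries + [max_length]]

  def example_length(example):
    # Group by the longer of the two sequences, as the docstring describes.
    return tf.maximum(tf.shape(example['inputs'])[0],
                      tf.shape(example['targets'])[0])

  return dataset.apply(
      tf.data.experimental.bucket_by_sequence_length(
          element_length_func=example_length,
          bucket_boundaries=boundaries,
          bucket_batch_sizes=batch_sizes))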
@@ -181,7 +165,11 @@ class WMTDataConfig(cfg.DataConfig):
   """Data config for WMT translation."""
   max_seq_length: int = 64
   static_batch: bool = False
-  vocab_file: str = ''
+  sentencepiece_model_path: str = ''
+  src_lang: str = ''
+  tgt_lang: str = ''
+  transform_and_batch: bool = True
+  has_unique_id: bool = False


 @data_loader_factory.register_data_loader_cls(WMTDataConfig)
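The new config fields above replace the pre-tokenized vocab_file setup with on-the-fly SentencePiece tokenization of raw-text features. A minimal usage sketch, mirroring the updated test below; the paths and the 'en'/'reverse_en' feature keys are placeholders, and a trained SentencePiece .model file is assumed to exist:

from official.nlp.data import wmt_dataloader

data_config = wmt_dataloader.WMTDataConfig(
    input_path='/tmp/train.record',            # TFRecords with raw text features
    src_lang='en',                             # feature key of the source text
    tgt_lang='reverse_en',                     # feature key of the target text
    sentencepiece_model_path='/tmp/sp.model',  # serialized SentencePiece model
    max_seq_length=35,
    global_batch_size=100,                     # token budget when is_training=True
    is_training=True,
    static_batch=False)
dataset = wmt_dataloader.WMTDataLoader(data_config).load()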
@@ -193,24 +181,20 @@ class WMTDataLoader(data_loader.DataLoader):
     self._max_seq_length = params.max_seq_length
     self._static_batch = params.static_batch
     self._global_batch_size = params.global_batch_size
+    if self._params.transform_and_batch:
+      self._tokenizer = tftxt.SentencepieceTokenizer(
+          model=tf.io.gfile.GFile(params.sentencepiece_model_path, 'rb').read(),
+          add_eos=True)

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
-    name_to_features = {
-        'inputs': tf.io.VarLenFeature(tf.int64),
-        'targets': tf.io.VarLenFeature(tf.int64)
-    }
-    example = tf.io.parse_single_example(record, name_to_features)
-    example['inputs'] = tf.sparse.to_dense(example['inputs'])
-    example['targets'] = tf.sparse.to_dense(example['targets'])
+    if self._params.is_training:
+      name_to_features = {
+          self._params.src_lang: tf.io.FixedLenFeature([], tf.string),
+          self._params.tgt_lang: tf.io.FixedLenFeature([], tf.string),
+      }
+      if self._params.has_unique_id:
+        name_to_features['unique_id'] = tf.io.FixedLenFeature([], tf.int64)
+      example = tf.io.parse_single_example(record, name_to_features)
+    else:
+      name_to_features = {
+          'inputs': tf.io.VarLenFeature(tf.int64),
+          'unique_id': tf.io.FixedLenFeature([], tf.int64)
+      }
+      example = tf.io.parse_single_example(record, name_to_features)
+      example['inputs'] = tf.sparse.to_dense(example['inputs'])
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
     # So cast all int64 to int32.
     for name in example:
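The loader now builds a tensorflow_text SentencepieceTokenizer from the serialized model bytes and relies on add_eos=True to append the EOS id to every tokenized sentence. A small standalone sketch of that behaviour, not taken from the commit; the model path is a placeholder for any trained SentencePiece model:

import tensorflow as tf
import tensorflow_text as tftxt

sp_model = tf.io.gfile.GFile('/tmp/sp.model', 'rb').read()
tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_eos=True)

# tokenize() returns a tf.RaggedTensor of token ids; each row ends with EOS.
tokens = tokenizer.tokenize(['abc ede fg', 'bbcd ef a g'])
print(tokens.to_list())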
@@ -220,21 +204,64 @@ class WMTDataLoader(data_loader.DataLoader):
         example[name] = t
     return example

-  def _bucketize_and_batch(
+  def _tokenize(self, inputs) -> Dict[str, tf.Tensor]:
+    tokenized_inputs = {}
+    for k, v in inputs.items():
+      if k == self._params.src_lang:
+        tokenized_inputs['inputs'] = self._tokenizer.tokenize(v)
+      elif k == self._params.tgt_lang:
+        tokenized_inputs['targets'] = self._tokenizer.tokenize(v)
+      else:
+        tokenized_inputs[k] = v
+    print(tokenized_inputs)
+    return tokenized_inputs
+
+  def _filter_max_length(self, inputs):
+    # return tf.constant(True)
+    return tf.logical_and(
+        tf.shape(inputs['inputs'])[0] <= self._max_seq_length,
+        tf.shape(inputs['targets'])[0] <= self._max_seq_length)
+
+  def _maybe_truncate(self, inputs):
+    truncated_inputs = {}
+    for k, v in inputs.items():
+      if k == 'inputs' or k == 'targets':
+        truncated_inputs[k] = tf.pad(
+            v[:self._max_seq_length - 1], [[0, 1]],
+            constant_values=1) if tf.shape(v)[0] > self._max_seq_length else v
+      else:
+        truncated_inputs[k] = v
+    return truncated_inputs
+
+  def _tokenize_bucketize_and_batch(
       self,
       dataset,
       input_context: Optional[tf.distribute.InputContext] = None):
-    # pylint: disable=g-long-lambda
-    dataset = dataset.filter(
-        lambda x: _filter_max_length(
-            (x['inputs'], x['targets']), self._max_seq_length))
-    # pylint: enable=g-long-lambda
+    dataset = dataset.map(
+        self._tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+    if self._params.is_training:
+      dataset = dataset.filter(self._filter_max_length)
+    else:
+      dataset = dataset.map(
+          self._maybe_truncate,
+          num_parallel_calls=tf.data.experimental.AUTOTUNE)

     per_replica_batch_size = input_context.get_per_replica_batch_size(
         self._global_batch_size) if input_context else self._global_batch_size
     if self._static_batch:
-      padded_shapes = dict([(name, [self._max_seq_length])
-                            for name, _ in dataset.element_spec.items()])
+      padded_shapes = {}
+      for name, _ in dataset.element_spec.items():
+        if name == 'unique_id':
+          padded_shapes[name] = []
+        else:
+          padded_shapes[name] = [
+              self._max_seq_length
+          ] if self._static_batch else [None]
+      batch_size = per_replica_batch_size
+      if self._params.is_training:
+        batch_size = int(batch_size // self._max_seq_length)
       dataset = dataset.padded_batch(
-          int(per_replica_batch_size // self._max_seq_length),
+          batch_size,
           padded_shapes,
           drop_remainder=True)
     else:
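The new _maybe_truncate above clips over-length evaluation examples to max_seq_length - 1 tokens and then pads one trailing position with the constant 1 (the EOS id configured for the SentencePiece model in the test), so truncated sequences still terminate with EOS. A tiny illustration of that tf.pad pattern; the values are made up:

import tensorflow as tf

max_seq_length = 5
v = tf.constant([11, 12, 13, 14, 15, 16, 17])  # longer than max_seq_length

# Keep max_seq_length - 1 tokens, then append one element with value 1 (EOS).
truncated = tf.pad(v[:max_seq_length - 1], [[0, 1]], constant_values=1)
print(truncated.numpy())  # [11 12 13 14  1]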
@@ -245,27 +272,24 @@ class WMTDataLoader(data_loader.DataLoader):
       dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
     return dataset

-  def _inference_padded_batch(
-      self,
-      dataset,
-      input_context: Optional[tf.distribute.InputContext] = None):
-    padded_shapes = {}
-    for name, _ in dataset.element_spec.items():
-      if name == 'unique_id':
-        padded_shapes[name] = []
-      else:
-        padded_shapes[name] = [
-            self._max_seq_length
-        ] if self._static_batch else [None]
-    per_replica_batch_size = input_context.get_per_replica_batch_size(
-        self._global_batch_size) if input_context else self._global_batch_size
-    return dataset.padded_batch(
-        per_replica_batch_size, padded_shapes, drop_remainder=True)
-
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
+    decoder_fn = None
+    # Only decode for TFRecords.
+    if self._params.input_path:
+      decoder_fn = self._decode
+
+    def _identity(
+        dataset, input_context: Optional[tf.distribute.InputContext] = None):
+      del input_context
+      return dataset
+
+    transform_and_batch_fn = _identity
+    if self._params.transform_and_batch:
+      transform_and_batch_fn = self._tokenize_bucketize_and_batch
+
     reader = input_reader.InputReader(
         params=self._params,
-        decoder_fn=self._decode,
-        transform_and_batch_fn=self._bucketize_and_batch
-        if self._params.is_training else self._inference_padded_batch)
+        decoder_fn=decoder_fn,
+        transform_and_batch_fn=transform_and_batch_fn)
     return reader.read(input_context)
official/nlp/data/wmt_dataloader_test.py (view file @ b4b675db)
@@ -15,74 +15,113 @@
 # ==============================================================================
 """Tests for official.nlp.data.wmt_dataloader."""
 import os
-import random
-
-from absl import logging
-import numpy as np
+from absl.testing import parameterized
 import tensorflow as tf
+from sentencepiece import SentencePieceTrainer

 from official.nlp.data import wmt_dataloader


-def _create_fake_dataset(output_path):
-  """Creates a fake dataset."""
-  writer = tf.io.TFRecordWriter(output_path)
-
-  def create_int_feature(values):
-    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
-    return f
-
-  for _ in range(20):
-    features = {}
-    seq_length = random.randint(20, 40)
-    input_ids = np.random.randint(100, size=(seq_length))
-    features['inputs'] = create_int_feature(input_ids)
-    seq_length = random.randint(10, 80)
-    targets = np.random.randint(100, size=(seq_length))
-    features['targets'] = create_int_feature(targets)
-    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
-    writer.write(tf_example.SerializeToString())
+def _generate_line_file(filepath, lines):
+  with tf.io.gfile.GFile(filepath, 'w') as f:
+    for l in lines:
+      f.write('{}\n'.format(l))
+
+
+def _generate_record_file(filepath, src_lines, tgt_lines, unique_id=False):
+  writer = tf.io.TFRecordWriter(filepath)
+  for i, (src, tgt) in enumerate(zip(src_lines, tgt_lines)):
+    features = {
+        'en': tf.train.Feature(
+            bytes_list=tf.train.BytesList(
+                value=[src.encode()])),
+        'reverse_en': tf.train.Feature(
+            bytes_list=tf.train.BytesList(
+                value=[tgt.encode()])),
+    }
+    if unique_id:
+      features['unique_id'] = tf.train.Feature(
+          int64_list=tf.train.Int64List(value=[i]))
+    example = tf.train.Example(
+        features=tf.train.Features(
+            feature=features))
+    writer.write(example.SerializeToString())
   writer.close()


-class WMTDataLoaderTest(tf.test.TestCase):
-
-  def test_load_dataset(self):
-    batch_tokens_size = 100
-    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
-    _create_fake_dataset(train_data_path)
-    data_config = wmt_dataloader.WMTDataConfig(
-        input_path=train_data_path,
-        max_seq_length=35,
-        global_batch_size=batch_tokens_size,
-        is_training=True,
-        static_batch=False)
-    dataset = wmt_dataloader.WMTDataLoader(data_config).load()
-    examples = next(iter(dataset))
-    inputs, targets = examples['inputs'], examples['targets']
-    logging.info('dynamic inputs=%s targets=%s', inputs, targets)
+def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
+  argstr = ' '.join([
+      f'--input={input_path}', f'--vocab_size={vocab_size}',
+      '--character_coverage=0.995', f'--model_prefix={model_path}',
+      '--model_type=bpe', '--bos_id=-1', '--pad_id=0', f'--eos_id={eos_id}',
+      '--unk_id=2'
+  ])
+  SentencePieceTrainer.Train(argstr)
+
+
+class WMTDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super(WMTDataLoaderTest, self).setUp()
+    self._temp_dir = self.get_temp_dir()
+    src_lines = [
+        'abc ede fg',
+        'bbcd ef a g',
+        'de f a a g'
+    ]
+    tgt_lines = [
+        'dd cc a ef g',
+        'bcd ef a g',
+        'gef cd ba'
+    ]
+    self._record_train_input_path = os.path.join(self._temp_dir, 'train.record')
+    _generate_record_file(self._record_train_input_path, src_lines, tgt_lines)
+    self._record_test_input_path = os.path.join(self._temp_dir, 'test.record')
+    _generate_record_file(self._record_test_input_path, src_lines, tgt_lines,
+                          unique_id=True)
+    self._sentencepeice_input_path = os.path.join(self._temp_dir, 'inputs.txt')
+    _generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
+    sentencepeice_model_prefix = os.path.join(self._temp_dir, 'sp')
+    _train_sentencepiece(self._sentencepeice_input_path, 20,
+                         sentencepeice_model_prefix)
+    self._sentencepeice_model_path = '{}.model'.format(
+        sentencepeice_model_prefix)
+
+  @parameterized.named_parameters(
+      ('train_static', True, True, 100, (2, 35)),
+      ('train_non_static', True, False, 100, (12, 7)),
+      ('non_train_static', False, True, 3, (3, 35)),
+      ('non_train_non_static', False, False, 50, (2, 7)),)
+  def test_load_dataset(self, is_training, static_batch, batch_size,
+                        expected_shape):
     data_config = wmt_dataloader.WMTDataConfig(
-        input_path=train_data_path,
+        input_path=self._record_train_input_path
+        if is_training else self._record_test_input_path,
         max_seq_length=35,
-        global_batch_size=batch_tokens_size,
-        is_training=True,
-        static_batch=True)
+        global_batch_size=batch_size,
+        is_training=is_training,
+        static_batch=static_batch,
+        src_lang='en',
+        tgt_lang='reverse_en',
+        sentencepiece_model_path=self._sentencepeice_model_path)
     dataset = wmt_dataloader.WMTDataLoader(data_config).load()
     examples = next(iter(dataset))
     inputs, targets = examples['inputs'], examples['targets']
-    logging.info('static inputs=%s targets=%s', inputs, targets)
-    self.assertEqual(inputs.shape, (2, 35))
-    self.assertEqual(targets.shape, (2, 35))
+    self.assertEqual(inputs.shape, expected_shape)
+    self.assertEqual(targets.shape, expected_shape)

   def test_load_dataset_raise_invalid_window(self):
     batch_tokens_size = 10  # this is too small to form buckets.
-    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
-    _create_fake_dataset(train_data_path)
     data_config = wmt_dataloader.WMTDataConfig(
-        input_path=train_data_path,
+        input_path=self._record_train_input_path,
         max_seq_length=100,
         global_batch_size=batch_tokens_size,
-        is_training=True)
+        is_training=True,
+        static_batch=False,
+        src_lang='en',
+        tgt_lang='reverse_en',
+        sentencepiece_model_path=self._sentencepeice_model_path)
     with self.assertRaisesRegex(
         ValueError, 'The token budget, global batch size, is too small.*'):
       _ = wmt_dataloader.WMTDataLoader(data_config).load()
official/nlp/modeling/models/seq2seq_transformer.py (view file @ b4b675db)
@@ -53,6 +53,7 @@ class Seq2SeqTransformer(tf.keras.Model):
                encoder_layer=None,
                decoder_layer=None,
                dtype=tf.float32,
+               eos_id=EOS_ID,
                **kwargs):
     """Initialize layers to build Transformer model.
@@ -69,6 +70,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       encoder_layer: An initialized encoder layer.
       decoder_layer: An initialized decoder layer.
       dtype: float dtype.
+      eos_id: Id of end of sentence token.
       **kwargs: other keyword arguments.
     """
     super(Seq2SeqTransformer, self).__init__(**kwargs)
@@ -81,6 +83,7 @@ class Seq2SeqTransformer(tf.keras.Model):
     self._beam_size = beam_size
     self._alpha = alpha
     self._dtype = dtype
+    self._eos_id = eos_id
     self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
         vocab_size=self._vocab_size,
         embedding_width=self._embedding_width,
@@ -102,6 +105,7 @@ class Seq2SeqTransformer(tf.keras.Model):
         "padded_decode": self._padded_decode,
         "decode_max_length": self._decode_max_length,
         "dtype": self._dtype,
+        "eos_id": self._eos_id,
         "extra_decode_length": self._extra_decode_length,
         "beam_size": self._beam_size,
         "alpha": self._alpha,
@@ -226,7 +230,7 @@ class Seq2SeqTransformer(tf.keras.Model):
         beam_size=self._beam_size,
         alpha=self._alpha,
         max_decode_length=max_decode_length,
-        eos_id=EOS_ID,
+        eos_id=self._eos_id,
         padded_decode=self._padded_decode,
         dtype=self._dtype)