Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7581884d
Unverified
Commit
7581884d
authored
Aug 19, 2020
by
Suraj Patil
Committed by
GitHub
Aug 19, 2020
Browse files
[BartTokenizerFast] add prepare_seq2seq_batch (#6543)
parent
8bcceace
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
157 additions
and
41 deletions
+157
-41
src/transformers/tokenization_bart.py
src/transformers/tokenization_bart.py
+103
-0
tests/test_modeling_bart.py
tests/test_modeling_bart.py
+54
-41
No files found.
src/transformers/tokenization_bart.py
View file @
7581884d
...
@@ -155,3 +155,106 @@ class BartTokenizerFast(RobertaTokenizerFast):
...
@@ -155,3 +155,106 @@ class BartTokenizerFast(RobertaTokenizerFast):
"vocab_file"
:
{
m
:
vocab_url
for
m
in
_all_bart_models
},
"vocab_file"
:
{
m
:
vocab_url
for
m
in
_all_bart_models
},
"merges_file"
:
{
m
:
merges_url
for
m
in
_all_bart_models
},
"merges_file"
:
{
m
:
merges_url
for
m
in
_all_bart_models
},
}
}
def
prepare_seq2seq_batch
(
self
,
src_texts
:
List
[
str
],
tgt_texts
:
Optional
[
List
[
str
]]
=
None
,
max_length
:
Optional
[
int
]
=
None
,
max_target_length
:
Optional
[
int
]
=
None
,
padding
:
str
=
"longest"
,
return_tensors
:
str
=
"None"
,
truncation
=
True
,
**
kwargs
,
)
->
BatchEncoding
:
r
"""
Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.
Args:
src_texts: (:obj:`List[str]`):
List of documents to summarize or source language texts.
tgt_texts: (:obj:`List[str]`, `optional`):
List of summaries or target language texts.
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries).
If left unset or set to :obj:`None`, this will use the max_length value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
**kwargs:
Additional keyword arguments passed along to :obj:`self.__call__`.
Returns:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
This does not include causal mask, which is built by the model.
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
if
max_length
is
None
:
max_length
=
self
.
model_max_length
model_inputs
:
BatchEncoding
=
self
(
src_texts
,
add_special_tokens
=
True
,
return_tensors
=
return_tensors
,
max_length
=
max_length
,
padding
=
padding
,
truncation
=
truncation
,
**
kwargs
,
)
if
tgt_texts
is
None
:
return
model_inputs
# Process tgt_texts
if
max_target_length
is
None
:
max_target_length
=
max_length
decoder_inputs
:
BatchEncoding
=
self
(
tgt_texts
,
add_special_tokens
=
True
,
return_tensors
=
return_tensors
,
padding
=
padding
,
max_length
=
max_target_length
,
truncation
=
truncation
,
**
kwargs
,
)
for
k
,
v
in
decoder_inputs
.
items
():
model_inputs
[
f
"decoder_
{
k
}
"
]
=
v
return
model_inputs
tests/test_modeling_bart.py
View file @
7581884d
...
@@ -38,6 +38,7 @@ if is_torch_available():
...
@@ -38,6 +38,7 @@ if is_torch_available():
BartForQuestionAnswering
,
BartForQuestionAnswering
,
BartConfig
,
BartConfig
,
BartTokenizer
,
BartTokenizer
,
BartTokenizerFast
,
pipeline
,
pipeline
,
)
)
from
transformers.modeling_bart
import
(
from
transformers.modeling_bart
import
(
...
@@ -421,6 +422,10 @@ class BartModelIntegrationTests(unittest.TestCase):
...
@@ -421,6 +422,10 @@ class BartModelIntegrationTests(unittest.TestCase):
def
default_tokenizer
(
self
):
def
default_tokenizer
(
self
):
return
BartTokenizer
.
from_pretrained
(
"facebook/bart-large"
)
return
BartTokenizer
.
from_pretrained
(
"facebook/bart-large"
)
@
cached_property
def
default_tokenizer_fast
(
self
):
return
BartTokenizerFast
.
from_pretrained
(
"facebook/bart-large"
)
@
slow
@
slow
def
test_inference_no_head
(
self
):
def
test_inference_no_head
(
self
):
model
=
BartModel
.
from_pretrained
(
"facebook/bart-large"
).
to
(
torch_device
)
model
=
BartModel
.
from_pretrained
(
"facebook/bart-large"
).
to
(
torch_device
)
...
@@ -564,74 +569,82 @@ class BartModelIntegrationTests(unittest.TestCase):
...
@@ -564,74 +569,82 @@ class BartModelIntegrationTests(unittest.TestCase):
# TODO(SS): add test case that hits max_length
# TODO(SS): add test case that hits max_length
def
test_prepare_seq2seq_batch
(
self
):
def
test_prepare_seq2seq_batch
(
self
):
tokenizer
=
self
.
default_tokenizer
tokenizer
s
=
[
self
.
default_tokenizer
,
self
.
default_tokenizer_fast
]
src_text
=
[
"A long paragraph for summrization."
,
"Another paragraph for summrization."
]
src_text
=
[
"A long paragraph for summrization."
,
"Another paragraph for summrization."
]
tgt_text
=
[
tgt_text
=
[
"Summary of the text."
,
"Summary of the text."
,
"Another summary."
,
"Another summary."
,
]
]
expected_src_tokens
=
[
0
,
250
,
251
,
17818
,
13
,
32933
,
21645
,
1258
,
4
,
2
]
expected_src_tokens
=
[
0
,
250
,
251
,
17818
,
13
,
32933
,
21645
,
1258
,
4
,
2
]
batch
=
tokenizer
.
prepare_seq2seq_batch
(
src_text
,
tgt_texts
=
tgt_text
,
max_length
=
len
(
expected_src_tokens
),
return_tensors
=
"pt"
)
self
.
assertIsInstance
(
batch
,
BatchEncoding
)
self
.
assertEqual
((
2
,
10
),
batch
.
input_ids
.
shape
)
for
tokenizer
in
tokenizers
:
self
.
assertEqual
((
2
,
10
),
batch
.
attention_mask
.
shape
)
batch
=
tokenizer
.
prepare_seq2seq_batch
(
result
=
batch
.
input_ids
.
tolist
()[
0
]
src_text
,
tgt_texts
=
tgt_text
,
max_length
=
len
(
expected_src_tokens
),
return_tensors
=
"pt"
self
.
assertListEqual
(
expected_src_tokens
,
result
)
)
# Test that special tokens are reset
self
.
assertIsInstance
(
batch
,
BatchEncoding
)
self
.
assertEqual
((
2
,
10
),
batch
.
input_ids
.
shape
)
self
.
assertEqual
((
2
,
10
),
batch
.
attention_mask
.
shape
)
result
=
batch
.
input_ids
.
tolist
()[
0
]
self
.
assertListEqual
(
expected_src_tokens
,
result
)
# Test that special tokens are reset
def
test_empty_target_text
(
self
):
def
test_empty_target_text
(
self
):
tokenizer
=
self
.
default_tokenizer
tokenizer
s
=
[
self
.
default_tokenizer
,
self
.
default_tokenizer_fast
]
src_text
=
[
"A long paragraph for summrization."
,
"Another paragraph for summrization."
]
src_text
=
[
"A long paragraph for summrization."
,
"Another paragraph for summrization."
]
batch
=
tokenizer
.
prepare_seq2seq_batch
(
src_text
,
return_tensors
=
"pt"
)
for
tokenizer
in
tokenizers
:
# check if input_ids are returned and no decoder_input_ids
batch
=
tokenizer
.
prepare_seq2seq_batch
(
src_text
,
return_tensors
=
"pt"
)
self
.
assertIn
(
"input_ids"
,
batch
)
# check if input_ids are returned and no decoder_input_ids
self
.
assertIn
(
"attention_mask"
,
batch
)
self
.
assertIn
(
"input_ids"
,
batch
)
self
.
assertNotIn
(
"decoder_input_ids"
,
batch
)
self
.
assertIn
(
"attention_mask"
,
batch
)
self
.
assertNotIn
(
"decoder_attention_mask"
,
batch
)
self
.
assertNotIn
(
"decoder_input_ids"
,
batch
)
self
.
assertNotIn
(
"decoder_attention_mask"
,
batch
)
def
test_max_target_length
(
self
):
def
test_max_target_length
(
self
):
tokenizer
=
self
.
default_tokenizer
tokenizer
s
=
[
self
.
default_tokenizer
,
self
.
default_tokenizer_fast
]
src_text
=
[
"A long paragraph for summrization."
,
"Another paragraph for summrization."
]
src_text
=
[
"A long paragraph for summrization."
,
"Another paragraph for summrization."
]
tgt_text
=
[
tgt_text
=
[
"Summary of the text."
,
"Summary of the text."
,
"Another summary."
,
"Another summary."
,
]
]
batch
=
tokenizer
.
prepare_seq2seq_batch
(
for
tokenizer
in
tokenizers
:
src_text
,
tgt_texts
=
tgt_text
,
max_target_length
=
32
,
padding
=
"max_length"
,
return_tensors
=
"pt"
batch
=
tokenizer
.
prepare_seq2seq_batch
(
)
src_text
,
tgt_texts
=
tgt_text
,
max_target_length
=
32
,
padding
=
"max_length"
,
return_tensors
=
"pt"
self
.
assertEqual
(
32
,
batch
[
"decoder_input_ids"
].
shape
[
1
])
)
self
.
assertEqual
(
32
,
batch
[
"decoder_attention_mask"
].
shape
[
1
])
self
.
assertEqual
(
32
,
batch
[
"decoder_input_ids"
].
shape
[
1
])
self
.
assertEqual
(
32
,
batch
[
"decoder_attention_mask"
].
shape
[
1
])
# test None max_target_length
# test None max_target_length
batch
=
tokenizer
.
prepare_seq2seq_batch
(
batch
=
tokenizer
.
prepare_seq2seq_batch
(
src_text
,
tgt_texts
=
tgt_text
,
max_length
=
32
,
padding
=
"max_length"
,
return_tensors
=
"pt"
src_text
,
tgt_texts
=
tgt_text
,
max_length
=
32
,
padding
=
"max_length"
,
return_tensors
=
"pt"
)
)
self
.
assertEqual
(
32
,
batch
[
"decoder_input_ids"
].
shape
[
1
])
self
.
assertEqual
(
32
,
batch
[
"decoder_input_ids"
].
shape
[
1
])
self
.
assertEqual
(
32
,
batch
[
"decoder_attention_mask"
].
shape
[
1
])
self
.
assertEqual
(
32
,
batch
[
"decoder_attention_mask"
].
shape
[
1
])
def
test_outputs_not_longer_than_maxlen
(
self
):
def
test_outputs_not_longer_than_maxlen
(
self
):
tokenizer
=
self
.
default_tokenizer
tokenizer
s
=
[
self
.
default_tokenizer
,
self
.
default_tokenizer_fast
]
batch
=
tokenizer
.
prepare_seq2seq_batch
([
"I am a small frog"
*
1024
,
"I am a small frog"
],
return_tensors
=
"pt"
)
for
tokenizer
in
tokenizers
:
self
.
assertIsInstance
(
batch
,
BatchEncoding
)
batch
=
tokenizer
.
prepare_seq2seq_batch
(
self
.
assertEqual
(
batch
.
input_ids
.
shape
,
(
2
,
1024
))
[
"I am a small frog"
*
1024
,
"I am a small frog"
],
return_tensors
=
"pt"
)
self
.
assertIsInstance
(
batch
,
BatchEncoding
)
self
.
assertEqual
(
batch
.
input_ids
.
shape
,
(
2
,
1024
))
def
test_special_tokens
(
self
):
def
test_special_tokens
(
self
):
tokenizer
=
self
.
default_tokenizer
tokenizer
s
=
[
self
.
default_tokenizer
,
self
.
default_tokenizer_fast
]
src_text
=
[
"A long paragraph for summrization."
]
src_text
=
[
"A long paragraph for summrization."
]
tgt_text
=
[
tgt_text
=
[
"Summary of the text."
,
"Summary of the text."
,
]
]
batch
=
tokenizer
.
prepare_seq2seq_batch
(
src_text
,
tgt_texts
=
tgt_text
,
return_tensors
=
"pt"
)
for
tokenizer
in
tokenizers
:
input_ids
=
batch
[
"input_ids"
]
batch
=
tokenizer
.
prepare_seq2seq_batch
(
src_text
,
tgt_texts
=
tgt_text
,
return_tensors
=
"pt"
)
decoder_input_ids
=
batch
[
"decoder_input_ids"
]
input_ids
=
batch
[
"input_ids"
]
self
.
assertTrue
((
input_ids
[:,
0
]
==
tokenizer
.
bos_token_id
).
all
().
item
())
decoder_input_ids
=
batch
[
"decoder_input_ids"
]
self
.
assertTrue
((
decoder_input_ids
[:,
0
]
==
tokenizer
.
bos_token_id
).
all
().
item
())
self
.
assertTrue
((
input_ids
[:,
0
]
==
tokenizer
.
bos_token_id
).
all
().
item
())
self
.
assertTrue
((
input_ids
[:,
-
1
]
==
tokenizer
.
eos_token_id
).
all
().
item
())
self
.
assertTrue
((
decoder_input_ids
[:,
0
]
==
tokenizer
.
bos_token_id
).
all
().
item
())
self
.
assertTrue
((
decoder_input_ids
[:,
-
1
]
==
tokenizer
.
eos_token_id
).
all
().
item
())
self
.
assertTrue
((
input_ids
[:,
-
1
]
==
tokenizer
.
eos_token_id
).
all
().
item
())
self
.
assertTrue
((
decoder_input_ids
[:,
-
1
]
==
tokenizer
.
eos_token_id
).
all
().
item
())
@
require_torch
@
require_torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment