chenpangpang / transformers / Commits / 2a77813d

Unverified commit 2a77813d, authored Aug 17, 2020 by Suraj Patil, committed by GitHub on Aug 17, 2020.
[BartTokenizer] add prepare s2s batch (#6212)
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Parent: 84d33317
Showing 2 changed files with 179 additions and 1 deletion (+179 -1):

  src/transformers/tokenization_bart.py   +103  -0
  tests/test_modeling_bart.py              +76  -1
src/transformers/tokenization_bart.py

@@ -42,6 +42,109 @@ class BartTokenizer(RobertaTokenizer):
         "merges_file": {m: merges_url for m in _all_bart_models},
     }
 
+    def prepare_seq2seq_batch(
+        self,
+        src_texts: List[str],
+        tgt_texts: Optional[List[str]] = None,
+        max_length: Optional[int] = None,
+        max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: Optional[str] = None,
+        truncation=True,
+        **kwargs,
+    ) -> BatchEncoding:
+        r"""
+        Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.
+
+        Args:
+            src_texts (:obj:`List[str]`):
+                List of documents to summarize or source language texts.
+            tgt_texts (:obj:`List[str]`, `optional`):
+                List of summaries or target language texts.
+            max_length (:obj:`int`, `optional`):
+                Controls the maximum length for encoder inputs (documents to summarize or source language texts).
+                If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
+                length is required by one of the truncation/padding parameters. If the model has no specific maximum
+                input length (like XLNet), truncation/padding to a maximum length will be deactivated.
+            max_target_length (:obj:`int`, `optional`):
+                Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
+                set to :obj:`None`, this will use the :obj:`max_length` value.
+            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
+                Activates and controls padding. Accepts the following values:
+
+                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
+                  single sequence is provided).
+                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided.
+                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
+                  different lengths).
+            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to :obj:`None`):
+                If set, will return tensors instead of lists of python integers. Acceptable values are:
+
+                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
+                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
+                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
+            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
+                Activates and controls truncation. Accepts the following values:
+
+                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
+                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
+                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
+                  if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
+                  the maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
+                  to the maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output a batch with sequence
+                  lengths greater than the model maximum admissible input size).
+            **kwargs:
+                Additional keyword arguments passed along to :obj:`self.__call__`.
+
+        Returns:
+            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to the encoder.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
+            - **decoder_input_ids** -- List of token ids to be fed to the decoder.
+            - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the
+              decoder. This does not include the causal mask, which is built by the model.
+
+            The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]`` will
+            only be returned if :obj:`tgt_texts` is passed. Otherwise, ``input_ids`` and ``attention_mask`` will be
+            the only keys.
+        """
+        if max_length is None:
+            max_length = self.model_max_length
+        model_inputs: BatchEncoding = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        if max_target_length is None:
+            max_target_length = max_length
+        decoder_inputs: BatchEncoding = self(
+            tgt_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            padding=padding,
+            max_length=max_target_length,
+            truncation=truncation,
+            **kwargs,
+        )
+        for k, v in decoder_inputs.items():
+            model_inputs[f"decoder_{k}"] = v
+        return model_inputs
+
+
 class BartTokenizerFast(RobertaTokenizerFast):
     # merges and vocab same as Roberta
 ...
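For context, here is a minimal usage sketch of the new method. This is not part of the commit; the checkpoint name and the particular forward call are illustrative assumptions:

    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

    batch = tokenizer.prepare_seq2seq_batch(
        src_texts=["A long article to summarize."],
        tgt_texts=["A short summary."],
        return_tensors="pt",
    )
    # The single BatchEncoding carries encoder and decoder tensors side by side,
    # so it can be unpacked straight into the model's forward pass.
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=batch["decoder_input_ids"],
        decoder_attention_mask=batch["decoder_attention_mask"],
    )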
tests/test_modeling_bart.py

@@ -18,7 +18,8 @@ import unittest
 import timeout_decorator  # noqa
 
-from transformers import is_torch_available
+from transformers import BatchEncoding, is_torch_available
+from transformers.file_utils import cached_property
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
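The newly imported BatchEncoding supports both access styles the tests below rely on: dictionary lookup (batch["input_ids"]) and attribute access (batch.input_ids). A quick illustrative sketch, not part of the commit:

    from transformers import BatchEncoding

    enc = BatchEncoding({"input_ids": [[0, 50, 2]], "attention_mask": [[1, 1, 1]]})
    # Both styles reach the same underlying data.
    assert enc["input_ids"] == enc.input_ids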
@@ -416,6 +417,10 @@ TOLERANCE = 1e-4
 @require_torch
 class BartModelIntegrationTests(unittest.TestCase):
+    @cached_property
+    def default_tokenizer(self):
+        return BartTokenizer.from_pretrained("facebook/bart-large")
+
     @slow
     def test_inference_no_head(self):
         model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
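The default_tokenizer helper uses cached_property, so the tokenizer is built by from_pretrained once and then reused on repeated accesses within a test instance. A simplified sketch of the pattern (transformers.file_utils ships its own implementation; this version is an assumption for illustration):

    class cached_property(property):
        # Like @property, but computes the value once and caches it on the instance.
        def __get__(self, obj, objtype=None):
            if obj is None:
                return self
            attr = "__cached_" + self.fget.__name__
            if not hasattr(obj, attr):
                setattr(obj, attr, self.fget(obj))
            return getattr(obj, attr)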
@@ -559,6 +564,76 @@ class BartModelIntegrationTests(unittest.TestCase):
     # TODO(SS): run fairseq again with num_beams=2, min_len=20.
     # TODO(SS): add test case that hits max_length
 
+    def test_prepare_seq2seq_batch(self):
+        tokenizer = self.default_tokenizer
+        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
+        tgt_text = [
+            "Summary of the text.",
+            "Another summary.",
+        ]
+        expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
+
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+
+        self.assertEqual((2, 10), batch.input_ids.shape)
+        self.assertEqual((2, 10), batch.attention_mask.shape)
+        result = batch.input_ids.tolist()[0]
+        self.assertListEqual(expected_src_tokens, result)
+        # Test that special tokens are reset
+
+    def test_empty_target_text(self):
+        tokenizer = self.default_tokenizer
+        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
+        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
+        # check if input_ids are returned and no decoder_input_ids
+        self.assertIn("input_ids", batch)
+        self.assertIn("attention_mask", batch)
+        self.assertNotIn("decoder_input_ids", batch)
+        self.assertNotIn("decoder_attention_mask", batch)
+
+    def test_max_target_length(self):
+        tokenizer = self.default_tokenizer
+        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
+        tgt_text = [
+            "Summary of the text.",
+            "Another summary.",
+        ]
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
+        )
+        self.assertEqual(32, batch["decoder_input_ids"].shape[1])
+        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
+
+        # test None max_target_length
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
+        )
+        self.assertEqual(32, batch["decoder_input_ids"].shape[1])
+        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
+
+    def test_outputs_not_longer_than_maxlen(self):
+        tokenizer = self.default_tokenizer
+
+        batch = tokenizer.prepare_seq2seq_batch(
+            ["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+        self.assertEqual(batch.input_ids.shape, (2, 1024))
+
+    def test_special_tokens(self):
+        tokenizer = self.default_tokenizer
+        src_text = ["A long paragraph for summrization."]
+        tgt_text = [
+            "Summary of the text.",
+        ]
+        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
+        input_ids = batch["input_ids"]
+        decoder_input_ids = batch["decoder_input_ids"]
+        self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
+        self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item())
+        self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
+        self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item())
+
 
 @require_torch
 class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
 ...
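The special-token layout these tests assert can also be reproduced interactively outside the unittest harness. A sketch (downloads the facebook/bart-large vocabulary on first run):

    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    batch = tokenizer.prepare_seq2seq_batch(
        ["A long paragraph for summarization."],
        tgt_texts=["Summary of the text."],
        return_tensors="pt",
    )
    # Encoder and decoder sequences are both wrapped as <s> ... </s>.
    assert batch["input_ids"][0, 0].item() == tokenizer.bos_token_id
    assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id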