Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2a5c9900
Unverified
Commit
2a5c9900
authored
Feb 15, 2021
by
Suraj Patil
Committed by
GitHub
Feb 15, 2021
Browse files
fix RagTokenizer (#10167)
parent
c8d3fa0d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
6 deletions
+45
-6
src/transformers/models/rag/tokenization_rag.py
src/transformers/models/rag/tokenization_rag.py
+45
-6
No files found.
src/transformers/models/rag/tokenization_rag.py
View file @
2a5c9900
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
"""Tokenization classes for RAG."""
import
os
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
from
...tokenization_utils_base
import
BatchEncoding
...
...
@@ -28,6 +29,7 @@ class RagTokenizer:
def __init__(self, question_encoder, generator):
    """Wrap a pair of tokenizers used by RAG.

    Args:
        question_encoder: tokenizer used for encoding input questions.
        generator: tokenizer used for the generator (decoding / targets).
    """
    self.question_encoder = question_encoder
    self.generator = generator
    # Encoding dispatches through this attribute; it defaults to the
    # question encoder and is swapped to the generator inside
    # ``as_target_tokenizer``.
    self.current_tokenizer = self.question_encoder
def
save_pretrained
(
self
,
save_directory
):
if
os
.
path
.
isfile
(
save_directory
):
...
...
@@ -57,23 +59,60 @@ class RagTokenizer:
return
cls
(
question_encoder
=
question_encoder
,
generator
=
generator
)
def __call__(self, *args, **kwargs):
    """Encode text with the currently active tokenizer.

    Fix: the previous implementation always called
    ``self.question_encoder``, so target texts tokenized inside
    ``as_target_tokenizer`` were (incorrectly) encoded with the question
    encoder's vocabulary. Dispatching through ``self.current_tokenizer``
    honors the context manager's temporary switch to the generator.

    Returns:
        Whatever the active underlying tokenizer returns (a
        ``BatchEncoding`` for HF tokenizers).
    """
    return self.current_tokenizer(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
    """Decode a batch of token-id sequences via the generator tokenizer.

    Decoding always goes through ``self.generator`` because model outputs
    are produced in the generator's vocabulary.
    """
    gen = self.generator
    return gen.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
    """Decode a single token-id sequence via the generator tokenizer.

    Mirrors :meth:`batch_decode`: model outputs live in the generator's
    vocabulary, so decoding is always delegated to ``self.generator``.
    """
    gen = self.generator
    return gen.decode(*args, **kwargs)
@contextmanager
def as_target_tokenizer(self):
    """
    Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
    sequence-to-sequence models that need a slightly different processing for the labels.
    """
    # Switch encoding to the generator for the duration of the block,
    # then restore the question encoder afterwards.
    self.current_tokenizer = self.generator
    try:
        yield
    finally:
        self.current_tokenizer = self.question_encoder
def prepare_seq2seq_batch(
    self,
    src_texts: List[str],
    tgt_texts: Optional[List[str]] = None,
    max_length: Optional[int] = None,
    max_target_length: Optional[int] = None,
    padding: str = "longest",
    return_tensors: str = None,
    truncation: bool = True,
    **kwargs,
) -> BatchEncoding:
    """Tokenize source texts (and optionally target texts) for seq2seq training.

    Fix: the previous implementation delegated to
    ``super().prepare_seq2seq_batch(...)``, but ``RagTokenizer`` is a plain
    wrapper class with no such base method, and it read ``model_max_length``
    directly from the question encoder / generator. This version encodes
    sources through ``self(...)`` (the question encoder) and targets inside
    :meth:`as_target_tokenizer` (the generator), taking each default
    ``max_length`` from the tokenizer that is actually active.

    Args:
        src_texts: input (question) texts to encode.
        tgt_texts: optional target texts; when given, their token ids are
            attached under ``"labels"``.
        max_length: max source length; defaults to the active tokenizer's
            ``model_max_length``.
        max_target_length: max target length; defaults to the generator's
            ``model_max_length``.
        padding: padding strategy passed through to the tokenizer.
        return_tensors: tensor format passed through to the tokenizer.
        truncation: truncation flag passed through to the tokenizer.

    Returns:
        A ``BatchEncoding`` of the encoded sources, with ``"labels"`` added
        when ``tgt_texts`` is provided.
    """
    if max_length is None:
        max_length = self.current_tokenizer.model_max_length
    model_inputs = self(
        src_texts,
        add_special_tokens=True,
        return_tensors=return_tensors,
        max_length=max_length,
        padding=padding,
        truncation=truncation,
        **kwargs,
    )
    if tgt_texts is None:
        return model_inputs
    # Process tgt_texts with the generator tokenizer so labels use the
    # generator's vocabulary.
    with self.as_target_tokenizer():
        if max_target_length is None:
            max_target_length = self.current_tokenizer.model_max_length
        labels = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=truncation,
            **kwargs,
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment