"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ec25306b39556f52222f50757703047e5a7584c3"
Unverified Commit 2a5c9900 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

fix RagTokenizer (#10167)

parent c8d3fa0d
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
"""Tokenization classes for RAG.""" """Tokenization classes for RAG."""
import os import os
from contextlib import contextmanager
from typing import List, Optional from typing import List, Optional
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
...@@ -28,6 +29,7 @@ class RagTokenizer: ...@@ -28,6 +29,7 @@ class RagTokenizer:
def __init__(self, question_encoder, generator): def __init__(self, question_encoder, generator):
self.question_encoder = question_encoder self.question_encoder = question_encoder
self.generator = generator self.generator = generator
self.current_tokenizer = self.question_encoder
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
...@@ -57,23 +59,60 @@ class RagTokenizer: ...@@ -57,23 +59,60 @@ class RagTokenizer:
return cls(question_encoder=question_encoder, generator=generator) return cls(question_encoder=question_encoder, generator=generator)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.question_encoder(*args, **kwargs) return self.current_tokenizer(*args, **kwargs)
def batch_decode(self, *args, **kwargs): def batch_decode(self, *args, **kwargs):
return self.generator.batch_decode(*args, **kwargs) return self.generator.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.generator.decode(*args, **kwargs)
@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.current_tokenizer = self.generator
yield
self.current_tokenizer = self.question_encoder
def prepare_seq2seq_batch(
    self,
    src_texts: List[str],
    tgt_texts: Optional[List[str]] = None,
    max_length: Optional[int] = None,
    max_target_length: Optional[int] = None,
    padding: str = "longest",
    return_tensors: str = None,
    truncation: bool = True,
    **kwargs,
) -> BatchEncoding:
    """Tokenize source texts (and optionally target texts) for seq2seq use.

    Sources are encoded with the currently active tokenizer; targets are
    encoded with the generator tokenizer (via ``as_target_tokenizer``) and
    attached to the result under the ``labels`` key.
    """
    # Default the source limit to the active tokenizer's maximum.
    if max_length is None:
        max_length = self.current_tokenizer.model_max_length
    shared = dict(
        add_special_tokens=True,
        return_tensors=return_tensors,
        padding=padding,
        truncation=truncation,
    )
    model_inputs = self(src_texts, max_length=max_length, **shared, **kwargs)
    if tgt_texts is None:
        return model_inputs
    # Encode the targets with the generator tokenizer.
    with self.as_target_tokenizer():
        if max_target_length is None:
            max_target_length = self.current_tokenizer.model_max_length
        labels = self(tgt_texts, max_length=max_target_length, **shared, **kwargs)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment