Unverified Commit 16469fed authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[PretrainedTokenizer] Factor out tensor conversion method (#3777)

parent 80a16945
...@@ -517,7 +517,7 @@ class PreTrainedTokenizer(SpecialTokensMixin): ...@@ -517,7 +517,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
# Padding side is right by default and over-riden in subclasses. If specified in the kwargs, it is changed. # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
self.padding_side = kwargs.pop("padding_side", self.padding_side) self.padding_side = kwargs.pop("padding_side", self.padding_side)
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
...@@ -1447,34 +1447,37 @@ class PreTrainedTokenizer(SpecialTokensMixin): ...@@ -1447,34 +1447,37 @@ class PreTrainedTokenizer(SpecialTokensMixin):
if return_tensors is not None: if return_tensors is not None:
# Do the tensor conversion in batch self.convert_to_tensors_(batch_outputs, return_tensors)
for key, value in batch_outputs.items(): return BatchEncoding(batch_outputs)
if return_tensors == "tf" and is_tf_available():
try: def convert_to_tensors_(self, batch_outputs: dict, return_tensors: str) -> None:
batch_outputs[key] = tf.constant(value) # Do the tensor conversion in batch
except ValueError: for key, value in batch_outputs.items():
if None in [item for sequence in value for item in sequence]: if return_tensors == "tf" and is_tf_available():
raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) try:
else: batch_outputs[key] = tf.constant(value)
raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) except ValueError:
elif return_tensors == "pt" and is_torch_available(): if None in [item for sequence in value for item in sequence]:
try: raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG)
batch_outputs[key] = torch.tensor(value) else:
except ValueError:
raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG)
except RuntimeError: elif return_tensors == "pt" and is_torch_available():
if None in [item for sequence in value for item in sequence]: try:
raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) batch_outputs[key] = torch.tensor(value)
else: except ValueError:
raise raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG)
elif return_tensors is not None: except RuntimeError:
logger.warning( if None in [item for sequence in value for item in sequence]:
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG)
return_tensors else:
) raise
)
return BatchEncoding(batch_outputs) elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors
)
)
def prepare_for_model( def prepare_for_model(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment