Commit f599623a authored by Julien Chaumond's avatar Julien Chaumond
Browse files

PreTrainedTokenizerFast: hotfix _convert_encoding

cc @n1t0
parent 16ce15ed
...@@ -1511,14 +1511,16 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): ...@@ -1511,14 +1511,16 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
# Prepare inputs as tensors if asked # Prepare inputs as tensors if asked
if return_tensors == "tf" and is_tf_available(): if return_tensors == "tf" and is_tf_available():
encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]]) encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]])
encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]]) if "token_type_ids" in encoding_dict:
encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]])
if "attention_mask" in encoding_dict: if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]]) encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]])
elif return_tensors == "pt" and is_torch_available(): elif return_tensors == "pt" and is_torch_available():
encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]]) encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]])
encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]]) if "token_type_ids" in encoding_dict:
encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]])
if "attention_mask" in encoding_dict: if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = torch.tensor([encoding_dict["attention_mask"]]) encoding_dict["attention_mask"] = torch.tensor([encoding_dict["attention_mask"]])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment