Commit c76c3ceb authored by Aditya Bhargava's avatar Aditya Bhargava Committed by Lysandre Debut
Browse files

Add check for token_type_ids before tensorizing

Fix an issue where `prepare_for_model()` gives a `KeyError` when
`return_token_type_ids` is set to `False` and `return_tensors` is
enabled.
parent eb59e9f7
...@@ -1194,6 +1194,8 @@ class PreTrainedTokenizer(object): ...@@ -1194,6 +1194,8 @@ class PreTrainedTokenizer(object):
# Prepare inputs as tensors if asked # Prepare inputs as tensors if asked
if return_tensors == "tf" and is_tf_available(): if return_tensors == "tf" and is_tf_available():
encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]]) encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]])
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]]) encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]])
if "attention_mask" in encoded_inputs: if "attention_mask" in encoded_inputs:
...@@ -1201,6 +1203,8 @@ class PreTrainedTokenizer(object): ...@@ -1201,6 +1203,8 @@ class PreTrainedTokenizer(object):
elif return_tensors == "pt" and is_torch_available(): elif return_tensors == "pt" and is_torch_available():
encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]]) encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]])
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]]) encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]])
if "attention_mask" in encoded_inputs: if "attention_mask" in encoded_inputs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment