"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "9611c2d0aae7a1a667a3eecaa92756fea1073f20"
Commit c76c3ceb authored by Aditya Bhargava's avatar Aditya Bhargava Committed by Lysandre Debut
Browse files

Add check for token_type_ids before tensorizing

Fix an issue where `prepare_for_model()` gives a `KeyError` when
`return_token_type_ids` is set to `False` and `return_tensors` is
enabled.
parent eb59e9f7
...@@ -1194,14 +1194,18 @@ class PreTrainedTokenizer(object): ...@@ -1194,14 +1194,18 @@ class PreTrainedTokenizer(object):
# Prepare inputs as tensors if asked # Prepare inputs as tensors if asked
if return_tensors == "tf" and is_tf_available(): if return_tensors == "tf" and is_tf_available():
encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]]) encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]])
encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]])
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]])
if "attention_mask" in encoded_inputs: if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = tf.constant([encoded_inputs["attention_mask"]]) encoded_inputs["attention_mask"] = tf.constant([encoded_inputs["attention_mask"]])
elif return_tensors == "pt" and is_torch_available(): elif return_tensors == "pt" and is_torch_available():
encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]]) encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]])
encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]])
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]])
if "attention_mask" in encoded_inputs: if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = torch.tensor([encoded_inputs["attention_mask"]]) encoded_inputs["attention_mask"] = torch.tensor([encoded_inputs["attention_mask"]])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment