Unverified Commit a450789d authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Disambiguate test for required_input in tokenization base file. (#20731)

* Disambiguate test for required_input in tokenization base file.

* Add test for size
parent 29ff8716
......@@ -24,7 +24,7 @@ import os
import re
import warnings
from collections import OrderedDict, UserDict
from collections.abc import Mapping
from collections.abc import Mapping, Sized
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
......@@ -2940,7 +2940,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
required_input = encoded_inputs[self.model_input_names[0]]
if not required_input:
if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0):
if return_attention_mask:
encoded_inputs["attention_mask"] = []
return encoded_inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment