Unverified Commit 18df4407 authored by Sylvain Gugger, committed by GitHub

Replace dict/BatchEncoding instance checks by Mapping (#17014)

* Replace dict/BatchEncoding instance checks by Mapping

* Typo
parent b8dffd1f
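
Context for the change: `BatchEncoding` subclasses `collections.UserDict`, and `UserDict` implements the `collections.abc.MutableMapping` interface, so a single `isinstance(x, Mapping)` check subsumes the old `(dict, BatchEncoding)` tuple while also accepting any other dict-like object. A minimal sketch of the equivalence, using a stand-in class rather than the real `BatchEncoding`:

```python
from collections import UserDict
from collections.abc import Mapping


class FakeBatchEncoding(UserDict):
    """Stand-in for transformers.BatchEncoding, which also subclasses UserDict."""


# Both a plain dict and the UserDict subclass pass the new check...
assert isinstance({}, Mapping)
assert isinstance(FakeBatchEncoding(), Mapping)
# ...while non-mapping containers are still rejected.
assert not isinstance([], Mapping)
```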
@@ -14,11 +14,12 @@
 import random
 import warnings
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

 from ..models.bert import BertTokenizer, BertTokenizerFast
-from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
+from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import PaddingStrategy
@@ -101,7 +102,7 @@ class DefaultDataCollator(DataCollatorMixin):
 def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import torch

-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -136,7 +137,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import numpy as np
     import tensorflow as tf

-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -177,7 +178,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
 def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import numpy as np

-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -687,7 +688,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         import tensorflow as tf

         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -724,7 +725,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -781,7 +782,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         import numpy as np

         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -858,7 +859,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     </Tip>"""

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
@@ -886,7 +887,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}

     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
@@ -914,7 +915,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}

     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
@@ -1207,21 +1208,21 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
     return_tensors: str = "pt"

     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _torch_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _tf_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             examples = [e["input_ids"] for e in examples]
         batch = _numpy_collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
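
The practical effect in these collators: a feature row that is a `Mapping` without being literally a `dict` or `BatchEncoding` no longer falls through to the `vars(f)` branch, which builds the batch from the object's attribute dictionary rather than its keys. A hypothetical illustration with a custom read-only mapping:

```python
from collections.abc import Mapping


class LazyRow(Mapping):
    """Hypothetical dict-like feature row that is neither dict nor BatchEncoding."""

    def __init__(self, data):
        self._data = data

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


features = [LazyRow({"input_ids": [101, 2023, 102], "label": 1})]
# Old check: isinstance(features[0], (dict, BatchEncoding)) is False, so the
# collator would call vars(f) and batch {"_data": ...} instead of the real keys.
# New check: isinstance(features[0], Mapping) is True, so the row is collated as-is.
assert isinstance(features[0], Mapping) and not isinstance(features[0], dict)
```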
...
@@ -21,6 +21,7 @@ import os
 import pickle
 import re
 import warnings
+from collections.abc import Mapping
 from typing import Dict, List, Optional, Union

 import h5py
@@ -39,7 +40,6 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation_tf_utils import TFGenerationMixin
 from .tf_utils import shape_list
-from .tokenization_utils_base import BatchEncoding
 from .utils import (
     DUMMY_INPUTS,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
@@ -471,7 +471,7 @@ def input_processing(func, config, input_ids, **kwargs):
                 raise ValueError(
                     f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
                 )
-    elif isinstance(input_ids, (dict, BatchEncoding)):
+    elif isinstance(input_ids, Mapping):
         if "inputs" in input_ids:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
...
@@ -15,6 +15,7 @@
 """ TensorFlow Hubert model."""
 import inspect
 import warnings
+from collections.abc import Mapping
 from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
@@ -24,7 +25,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
 from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -97,7 +97,7 @@ def input_values_processing(func, config, input_values, **kwargs):
                 raise ValueError(
                     f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
                 )
-    elif isinstance(input_values, (dict, BatchEncoding)):
+    elif isinstance(input_values, Mapping):
         if "inputs" in input_values:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
...
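
The `input_processing` and `input_values_processing` hunks above share the same dispatch on the shape of the first argument: a tuple/list of positional inputs, a mapping of named inputs, or a single tensor. With `Mapping`, a plain `dict`, a `BatchEncoding`, or any other mapping a tokenizer might return all take the same branch. A minimal sketch of that dispatch (a hypothetical helper, not the real function, which also validates input types against the model's signature):

```python
import warnings
from collections.abc import Mapping


def unpack_inputs_sketch(input_ids, **kwargs):
    """Hypothetical reduction of the dispatch performed by input_processing."""
    output = dict(kwargs)
    if isinstance(input_ids, (tuple, list)):
        # Positional inputs: map each tensor onto a parameter slot by position.
        for i, tensor in enumerate(input_ids):
            output[f"arg_{i}"] = tensor
    elif isinstance(input_ids, Mapping):
        if "inputs" in input_ids:
            warnings.warn("The `inputs` key is deprecated, use `input_ids` instead.", FutureWarning)
        output.update(input_ids)
    else:
        # A single tensor (or array) is treated as input_ids directly.
        output["input_ids"] = input_ids
    return output
```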
@@ -17,6 +17,7 @@
 import itertools
 import json
 import os
+from collections.abc import Mapping
 from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -1140,7 +1141,7 @@ class LukeTokenizer(RobertaTokenizer):
         """

         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

         # The model's main input name, usually `input_ids`, has be passed for padding
...
@@ -18,6 +18,7 @@
 import itertools
 import json
 import os
+from collections.abc import Mapping
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -1253,7 +1254,7 @@ class MLukeTokenizer(PreTrainedTokenizer):
         """

         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

         # The model's main input name, usually `input_ids`, has be passed for padding
...
@@ -16,6 +16,7 @@
 import inspect
 import warnings
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
@@ -26,7 +27,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
 from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -135,7 +135,7 @@ def input_values_processing(func, config, input_values, **kwargs):
                 raise ValueError(
                     f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
                 )
-    elif isinstance(input_values, (dict, BatchEncoding)):
+    elif isinstance(input_values, Mapping):
         if "inputs" in input_values:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
...
@@ -22,6 +22,7 @@ import shutil
 import sys
 import tempfile
 import unittest
+from collections.abc import Mapping
 from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
@@ -1459,13 +1460,11 @@ def nested_simplify(obj, decimals=3):
     """
     import numpy as np

-    from transformers.tokenization_utils import BatchEncoding
-
     if isinstance(obj, list):
         return [nested_simplify(item, decimals) for item in obj]
     elif isinstance(obj, np.ndarray):
         return nested_simplify(obj.tolist())
-    elif isinstance(obj, (dict, BatchEncoding)):
+    elif isinstance(obj, Mapping):
         return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
     elif isinstance(obj, (str, int, np.int64)):
         return obj
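
`nested_simplify` is a test helper that recursively rounds floats and unwraps containers so that approximate model outputs compare equal; with the `Mapping` branch it no longer needs the local `BatchEncoding` import and simplifies any mapping type the same way. Rough usage (values chosen for illustration):

```python
from transformers.testing_utils import nested_simplify

# Containers are walked recursively; floats are rounded to `decimals` places.
print(nested_simplify({"scores": [0.123456, 0.654321]}, decimals=3))
# expected: {'scores': [0.123, 0.654]}
```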
...
@@ -24,6 +24,7 @@ import os
 import re
 import warnings
 from collections import OrderedDict, UserDict
+from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
@@ -2768,7 +2769,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         """

         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

         # The model's main input name, usually `input_ids`, has be passed for padding
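
The comment in the hunk above describes the transposition that makes `tokenizer.pad` usable as a `DataLoader` collate_fn: a list of per-example dicts becomes a dict of lists before padding is applied. The same line in isolation (keys assumed identical across examples):

```python
# A batch as a DataLoader delivers it: one mapping per example.
examples = [
    {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1]},
    {"input_ids": [101, 102], "attention_mask": [1, 1]},
]

# The transposition performed inside pad() before padding is applied.
batch = {key: [example[key] for example in examples] for key in examples[0]}
# {'input_ids': [[101, 2023, 102], [101, 102]],
#  'attention_mask': [[1, 1, 1], [1, 1]]}
```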
...
@@ -22,6 +22,7 @@ import math
 import os
 import sys
 import warnings
+from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import StreamHandler
@@ -111,7 +112,7 @@ def find_batch_size(tensors):
             result = find_batch_size(t)
             if result is not None:
                 return result
-    elif isinstance(tensors, (dict, BatchEncoding)):
+    elif isinstance(tensors, Mapping):
         for key, value in tensors.items():
             result = find_batch_size(value)
             if result is not None:
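
`find_batch_size` walks an arbitrarily nested structure of lists, tuples, and mappings until it reaches a tensor and returns that tensor's leading dimension; the `Mapping` branch is what lets it descend into a `BatchEncoding` or a plain dict alike. A self-contained sketch of the same recursion over NumPy arrays (a simplification, not the real helper, which also handles torch tensors and scalars):

```python
from collections.abc import Mapping
from typing import Optional

import numpy as np


def find_batch_size_sketch(tensors) -> Optional[int]:
    """Return the leading dimension of the first array found, else None."""
    if isinstance(tensors, (list, tuple)):
        for t in tensors:
            result = find_batch_size_sketch(t)
            if result is not None:
                return result
    elif isinstance(tensors, Mapping):
        for value in tensors.values():
            result = find_batch_size_sketch(value)
            if result is not None:
                return result
    elif isinstance(tensors, np.ndarray) and tensors.ndim >= 1:
        return tensors.shape[0]
    return None


batch = {"input_ids": np.zeros((8, 128)), "labels": [np.zeros((8,))]}
assert find_batch_size_sketch(batch) == 8
```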
...