Unverified Commit 18df4407 authored by Sylvain Gugger, committed by GitHub

Replace dict/BatchEncoding instance checks by Mapping (#17014)

* Replace dict/BatchEncoding instance checks by Mapping

* Typo
parent b8dffd1f
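
For context, a minimal sketch (not part of this commit) of why the broader check is equivalent: BatchEncoding subclasses UserDict, which is a collections.abc.Mapping, so everything the old (dict, BatchEncoding) tuple matched still passes the new check, and other mapping types are now accepted too. The checkpoint name below is only an example.

from collections.abc import Mapping

from transformers import AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
encoding = tokenizer("hello world")  # returns a BatchEncoding

print(isinstance(encoding, dict))     # False -- why the old check listed BatchEncoding explicitly
print(isinstance(encoding, Mapping))  # True  -- covered by the new check
print(isinstance({"input_ids": [0]}, Mapping))  # True -- plain dicts still pass
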
@@ -14,11 +14,12 @@
 import random
 import warnings
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
 from ..models.bert import BertTokenizer, BertTokenizerFast
-from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
+from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import PaddingStrategy
@@ -101,7 +102,7 @@ class DefaultDataCollator(DataCollatorMixin):
 def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import torch
-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -136,7 +137,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import numpy as np
     import tensorflow as tf
-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -177,7 +178,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
 def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
     import numpy as np
-    if not isinstance(features[0], (dict, BatchEncoding)):
+    if not isinstance(features[0], Mapping):
         features = [vars(f) for f in features]
     first = features[0]
     batch = {}
@@ -687,7 +688,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         import tensorflow as tf
         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -724,7 +725,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -781,7 +782,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         import numpy as np
         # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
             batch = {
@@ -858,7 +859,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     </Tip>"""
     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
@@ -886,7 +887,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}
     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
@@ -914,7 +915,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         return {"input_ids": inputs, "labels": labels}
     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        if isinstance(examples[0], (dict, BatchEncoding)):
+        if isinstance(examples[0], Mapping):
             input_ids = [e["input_ids"] for e in examples]
         else:
             input_ids = examples
......@@ -1207,21 +1208,21 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
return_tensors: str = "pt"
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
if isinstance(examples[0], Mapping):
examples = [e["input_ids"] for e in examples]
batch = _torch_collate_batch(examples, self.tokenizer)
inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
if isinstance(examples[0], Mapping):
examples = [e["input_ids"] for e in examples]
batch = _tf_collate_batch(examples, self.tokenizer)
inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
if isinstance(examples[0], Mapping):
examples = [e["input_ids"] for e in examples]
batch = _numpy_collate_batch(examples, self.tokenizer)
inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
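
Illustrative usage, not part of the diff: with the Mapping check, the default data collator accepts plain dicts, BatchEncoding objects, or any other mapping as features. The toy token ids below are made up for the example.

from transformers import default_data_collator

features = [
    {"input_ids": [101, 2023, 102], "label": 0},
    {"input_ids": [101, 2003, 102], "label": 1},
]
batch = default_data_collator(features, return_tensors="pt")
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["labels"].shape)     # torch.Size([2])
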
@@ -21,6 +21,7 @@ import os
 import pickle
 import re
 import warnings
+from collections.abc import Mapping
 from typing import Dict, List, Optional, Union
 import h5py
@@ -39,7 +40,6 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation_tf_utils import TFGenerationMixin
 from .tf_utils import shape_list
-from .tokenization_utils_base import BatchEncoding
 from .utils import (
     DUMMY_INPUTS,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
@@ -471,7 +471,7 @@ def input_processing(func, config, input_ids, **kwargs):
                 raise ValueError(
                     f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
                 )
-    elif isinstance(input_ids, (dict, BatchEncoding)):
+    elif isinstance(input_ids, Mapping):
         if "inputs" in input_ids:
             warnings.warn(
                 "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
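
Illustrative usage, not part of the diff: input_processing is the helper that lets a TF model be called with a single mapping of inputs, such as the BatchEncoding returned by a tokenizer. The checkpoint name is only an example.

from transformers import AutoTokenizer, TFAutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
model = TFAutoModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("hello world", return_tensors="tf")  # a BatchEncoding, i.e. a Mapping
outputs = model(inputs)  # the mapping is unpacked by input_processing
print(outputs.last_hidden_state.shape)
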
@@ -15,6 +15,7 @@
 """ TensorFlow Hubert model."""
 import inspect
 import warnings
+from collections.abc import Mapping
 from typing import Any, Dict, Optional, Tuple, Union
 import numpy as np
@@ -24,7 +25,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
 from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
......@@ -97,7 +97,7 @@ def input_values_processing(func, config, input_values, **kwargs):
raise ValueError(
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
)
elif isinstance(input_values, (dict, BatchEncoding)):
elif isinstance(input_values, Mapping):
if "inputs" in input_values:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
@@ -17,6 +17,7 @@
 import itertools
 import json
 import os
+from collections.abc import Mapping
 from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
@@ -1140,7 +1141,7 @@ class LukeTokenizer(RobertaTokenizer):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
         # The model's main input name, usually `input_ids`, has be passed for padding
@@ -18,6 +18,7 @@
 import itertools
 import json
 import os
+from collections.abc import Mapping
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -1253,7 +1254,7 @@ class MLukeTokenizer(PreTrainedTokenizer):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
             encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
         # The model's main input name, usually `input_ids`, has be passed for padding
@@ -16,6 +16,7 @@
 import inspect
 import warnings
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
@@ -26,7 +27,6 @@ from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
 from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
 from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
......@@ -135,7 +135,7 @@ def input_values_processing(func, config, input_values, **kwargs):
raise ValueError(
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
)
elif isinstance(input_values, (dict, BatchEncoding)):
elif isinstance(input_values, Mapping):
if "inputs" in input_values:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
@@ -22,6 +22,7 @@ import shutil
 import sys
 import tempfile
 import unittest
+from collections.abc import Mapping
 from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
@@ -1459,13 +1460,11 @@ def nested_simplify(obj, decimals=3):
     """
     import numpy as np
-    from transformers.tokenization_utils import BatchEncoding
     if isinstance(obj, list):
         return [nested_simplify(item, decimals) for item in obj]
     elif isinstance(obj, np.ndarray):
         return nested_simplify(obj.tolist())
-    elif isinstance(obj, (dict, BatchEncoding)):
+    elif isinstance(obj, Mapping):
         return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
     elif isinstance(obj, (str, int, np.int64)):
         return obj
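
Illustrative usage, not part of the diff: nested_simplify rounds floats inside nested containers for equality checks in tests; with the Mapping check it now simplifies any dict-like object, not just dict and BatchEncoding.

from transformers.testing_utils import nested_simplify

print(nested_simplify({"score": 0.123456, "values": [1.987654, 2.0]}))
# {'score': 0.123, 'values': [1.988, 2.0]}
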
@@ -24,6 +24,7 @@ import os
 import re
 import warnings
 from collections import OrderedDict, UserDict
+from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
@@ -2768,7 +2769,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         """
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
         # The model's main input name, usually `input_ids`, has be passed for padding
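
Illustrative usage, not part of the diff: the list-of-mappings branch above is what makes tokenizer.pad usable as a DataLoader collate_fn, since each dataset row arrives as a dict-like example. The checkpoint name and token ids below are made up for the example.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
examples = [
    {"input_ids": [101, 7592, 102]},
    {"input_ids": [101, 7592, 2088, 999, 102]},
]
batch = tokenizer.pad(examples, padding=True, return_tensors="pt")
print(batch["input_ids"].shape)       # torch.Size([2, 5])
print(batch["attention_mask"].shape)  # torch.Size([2, 5])
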
@@ -22,6 +22,7 @@ import math
 import os
 import sys
 import warnings
+from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import StreamHandler
@@ -111,7 +112,7 @@ def find_batch_size(tensors):
             result = find_batch_size(t)
             if result is not None:
                 return result
-    elif isinstance(tensors, (dict, BatchEncoding)):
+    elif isinstance(tensors, Mapping):
         for key, value in tensors.items():
             result = find_batch_size(value)
             if result is not None:
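
Illustrative usage, not part of the diff: find_batch_size walks nested lists, tuples, and now any Mapping, returning the first dimension of the first tensor it finds.

import torch
from transformers.trainer_pt_utils import find_batch_size

nested = {"inputs": {"input_ids": torch.zeros(8, 16)}, "labels": [torch.zeros(8)]}
print(find_batch_size(nested))  # 8
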