"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "804c2974d5e1c95e71afe57f8f97b3a8bcd921eb"
Unverified Commit 91ff480e authored by Allan Lin, committed by GitHub

Update namespaces inside torch.utils.data to the latest. (#13167)

* Update torch.utils.data namespaces to the latest.

* Format

* Update Dataloader.

* Style

parent 1fec32ad
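For context on why this is a pure rename: `torch.utils.data` re-exports these classes at its top level, so the shortened imports below are drop-in replacements for the old deep module paths. A minimal sanity check, assuming a reasonably recent PyTorch build:

```python
# Sketch: verify that the flat torch.utils.data namespace exposes the same
# objects as the deep submodule paths this commit moves away from.
import torch.utils.data
import torch.utils.data.dataloader
import torch.utils.data.dataset
import torch.utils.data.sampler

# The deep paths still resolve, but they point at the very same classes,
# so every change in this diff is behavior-preserving.
assert torch.utils.data.dataset.Dataset is torch.utils.data.Dataset
assert torch.utils.data.dataset.IterableDataset is torch.utils.data.IterableDataset
assert torch.utils.data.dataloader.DataLoader is torch.utils.data.DataLoader
assert torch.utils.data.sampler.RandomSampler is torch.utils.data.RandomSampler
assert torch.utils.data.sampler.Sampler is torch.utils.data.Sampler
```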
@@ -77,7 +77,7 @@ class Split(Enum):
 if is_torch_available():
     import torch
-    from torch.utils.data.dataset import Dataset
+    from torch.utils.data import Dataset

     class MultipleChoiceDataset(Dataset):
         """
...
@@ -141,7 +141,7 @@ class Seq2SeqTrainer(Trainer):
         )
         return scheduler

-    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
             return None
         elif is_torch_tpu_available():
...
@@ -206,7 +206,7 @@ class TokenClassificationTask:
 if is_torch_available():
     import torch
     from torch import nn
-    from torch.utils.data.dataset import Dataset
+    from torch.utils.data import Dataset

     class TokenClassificationDataset(Dataset):
         """
...
@@ -31,7 +31,7 @@ import random
 import datasets
 import torch
 from datasets import load_dataset
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -31,7 +31,7 @@ import random
 import datasets
 import torch
 from datasets import load_dataset
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -29,7 +29,7 @@ from typing import Optional, Union
 import datasets
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -28,7 +28,7 @@ import datasets
 import numpy as np
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -28,7 +28,7 @@ import datasets
 import numpy as np
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -29,7 +29,7 @@ import nltk
 import numpy as np
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -21,7 +21,7 @@ import random
 import datasets
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -27,7 +27,7 @@ import random
 import datasets
 import torch
 from datasets import ClassLabel, load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -28,7 +28,7 @@ import datasets
 import numpy as np
 import torch
 from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
...
@@ -88,7 +88,7 @@ class InputFeatures:
 if is_torch_available():
     import torch
-    from torch.utils.data.dataset import Dataset
+    from torch.utils.data import Dataset

     class HansDataset(Dataset):
         """
...
@@ -19,7 +19,7 @@ import copy
 from collections import defaultdict

 import numpy as np
-from torch.utils.data.sampler import BatchSampler, Sampler
+from torch.utils.data import BatchSampler, Sampler

 from utils import logger
...
@@ -20,7 +20,7 @@ from enum import Enum
 from typing import List, Optional, Union

 import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset

 from filelock import FileLock
...
@@ -21,7 +21,7 @@ import warnings
 from typing import Dict, List, Optional

 import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset

 from filelock import FileLock
...
@@ -19,7 +19,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union

 import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset

 from filelock import FileLock
...
@@ -49,10 +49,8 @@ import numpy as np
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler, SequentialSampler

 from . import __version__
 from .configuration_utils import PretrainedConfig
@@ -206,16 +204,16 @@ class Trainer:
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
             Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
             :func:`~transformers.DataCollatorWithPadding` otherwise.
-        train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
+        train_dataset (:obj:`torch.utils.data.Dataset` or :obj:`torch.utils.data.IterableDataset`, `optional`):
             The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.

-            Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
-            training in a distributed fashion, your iterable dataset should either use a internal attribute
-            :obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
-            processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
-            :obj:`set_epoch()` method that internally sets the seed of the RNGs used.
-        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+            Note that if it's a :obj:`torch.utils.data.IterableDataset` with some randomization and you are training in
+            a distributed fashion, your iterable dataset should either use a internal attribute :obj:`generator` that
+            is a :obj:`torch.Generator` for the randomization that must be identical on all processes (and the Trainer
+            will manually set the seed of this :obj:`generator` at each epoch) or have a :obj:`set_epoch()` method that
+            internally sets the seed of the RNGs used.
+        eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
         tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
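The docstring above states a contract for iterable datasets under distributed training: expose either a :obj:`generator` attribute or a :obj:`set_epoch()` method so that every process shuffles identically each epoch. A minimal sketch of a dataset meeting that contract through `set_epoch()`; the class and its data layout are illustrative, not part of this diff:

```python
import torch
from torch.utils.data import IterableDataset


class ShuffledStream(IterableDataset):
    """Toy iterable dataset whose shuffle is reseeded identically on every
    process each epoch, as the Trainer docstring above requires."""

    def __init__(self, data, seed=0):
        self.data = list(data)
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # The Trainer calls this (or reseeds an internal `generator`) at the
        # start of each epoch on every process.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # identical on all processes
        for i in torch.randperm(len(self.data), generator=g):
            yield self.data[i]
```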
@@ -537,7 +535,7 @@ class Trainer:
         else:
             return dataset.remove_columns(ignored_columns)

-    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
@@ -617,7 +615,7 @@ class Trainer:
         if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")

-        if isinstance(train_dataset, torch.utils.data.dataset.IterableDataset):
+        if isinstance(train_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
                 train_dataset = IterableDatasetShard(
                     train_dataset,
@@ -647,7 +645,7 @@ class Trainer:
             pin_memory=self.args.dataloader_pin_memory,
         )

-    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
+    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
         # Deprecated code
         if self.args.use_legacy_prediction_loop:
             if is_torch_tpu_available():
@@ -683,7 +681,7 @@ class Trainer:
         Subclass and override this method if you want to inject some custom behavior.

         Args:
-            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+            eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
                 If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
                 accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
@@ -694,7 +692,7 @@ class Trainer:
         if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")

-        if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset):
+        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
                 eval_dataset = IterableDatasetShard(
                     eval_dataset,
@@ -730,14 +728,14 @@ class Trainer:
         Subclass and override this method if you want to inject some custom behavior.

         Args:
-            test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+            test_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
                 The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                 ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
         if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
             test_dataset = self._remove_unused_columns(test_dataset, description="test")

-        if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset):
+        if isinstance(test_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
                 test_dataset = IterableDatasetShard(
                     test_dataset,
...
@@ -175,9 +175,9 @@ class TrainerCallback:
             The optimizer used for the training steps.
         lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
             The scheduler used for setting the learning rate.
-        train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
+        train_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
             The current dataloader used for training.
-        eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
+        eval_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
             The current dataloader used for training.
         metrics (:obj:`Dict[str, float]`):
             The metrics computed by the last evaluation phase.
...
@@ -29,9 +29,8 @@ import numpy as np
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler, Sampler

 from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
 from .tokenization_utils_base import BatchEncoding
@@ -290,7 +289,7 @@ class SequentialDistributedSampler(Sampler):
         return self.num_samples


-def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
+def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
     return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
@@ -690,7 +689,7 @@ class IterableDatasetShard(IterableDataset):
     Args:
-        dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
+        dataset (:obj:`torch.utils.data.IterableDataset`):
             The batch sampler to split in several shards.
         batch_size (:obj:`int`, `optional`, defaults to 1):
             The size of the batches per shard.
...
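For intuition about the sharding this docstring describes, here is a toy sketch, not the library's `IterableDatasetShard` implementation (which also handles epoch seeding and ragged final batches): each process iterates the same stream and keeps only its own `batch_size`-sized slice of every global batch.

```python
from torch.utils.data import IterableDataset


class ToyShard(IterableDataset):
    """Each process consumes the same underlying iterable but keeps only its
    own slice of every global batch of `batch_size * num_processes` items."""

    def __init__(self, dataset, batch_size=1, num_processes=1, process_index=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_processes = num_processes
        self.process_index = process_index

    def __iter__(self):
        buffer = []
        for element in self.dataset:
            buffer.append(element)
            if len(buffer) == self.batch_size * self.num_processes:
                # Keep this process's batch_size-sized slice, drop the rest.
                start = self.process_index * self.batch_size
                yield from buffer[start : start + self.batch_size]
                buffer = []
```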