"ts/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "a16e570ddb6e976a78fc1234a0697019fd836ed1"
Unverified commit 91ff480e, authored by Allan Lin and committed by GitHub

Update namespaces inside torch.utils.data to the latest. (#13167)

* Update torch.utils.data namespaces to the latest.

* Format

* Update Dataloader.

* Style
parent 1fec32ad
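For context, the change is purely one of import paths: the classes that used to be imported from the torch.utils.data.dataset, torch.utils.data.dataloader and torch.utils.data.sampler submodules are re-exported by the public top-level torch.utils.data namespace, so behaviour is unchanged. A minimal sketch of the before/after style (illustrative only, not part of this commit):

# Old style: import from internal submodules.
from torch.utils.data.dataset import Dataset as OldDataset
from torch.utils.data.dataloader import DataLoader as OldDataLoader
from torch.utils.data.sampler import RandomSampler as OldRandomSampler

# New style: import from the public top-level namespace.
from torch.utils.data import DataLoader, Dataset, RandomSampler

# Both paths resolve to the same objects, so the migration is behaviour-preserving.
assert Dataset is OldDataset
assert DataLoader is OldDataLoader
assert RandomSampler is OldRandomSampler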
@@ -77,7 +77,7 @@ class Split(Enum):
if is_torch_available():
import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
class MultipleChoiceDataset(Dataset):
"""
......
@@ -141,7 +141,7 @@ class Seq2SeqTrainer(Trainer):
)
return scheduler
-def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
+def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
......
@@ -206,7 +206,7 @@ class TokenClassificationTask:
if is_torch_available():
import torch
from torch import nn
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
class TokenClassificationDataset(Dataset):
"""
......
@@ -31,7 +31,7 @@ import random
import datasets
import torch
from datasets import load_dataset
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -31,7 +31,7 @@ import random
import datasets
import torch
from datasets import load_dataset
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -29,7 +29,7 @@ from typing import Optional, Union
import datasets
import torch
from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -28,7 +28,7 @@ import datasets
import numpy as np
import torch
from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -28,7 +28,7 @@ import datasets
import numpy as np
import torch
from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -29,7 +29,7 @@ import nltk
import numpy as np
import torch
from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -21,7 +21,7 @@ import random
import datasets
from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -27,7 +27,7 @@ import random
import datasets
import torch
from datasets import ClassLabel, load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -28,7 +28,7 @@ import datasets
import numpy as np
import torch
from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
+from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
......
@@ -88,7 +88,7 @@ class InputFeatures:
if is_torch_available():
import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
class HansDataset(Dataset):
"""
......
@@ -19,7 +19,7 @@ import copy
from collections import defaultdict
import numpy as np
-from torch.utils.data.sampler import BatchSampler, Sampler
+from torch.utils.data import BatchSampler, Sampler
from utils import logger
......
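The hunk above updates a utility built around BatchSampler and Sampler; with the new import style, a custom sampler is written against the top-level names. A hedged sketch with a hypothetical EvenIndexSampler (not the class defined in the file being patched):

from torch.utils.data import BatchSampler, Sampler


class EvenIndexSampler(Sampler):
    """Hypothetical sampler that yields only the even indices of a dataset."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        return (len(self.data_source) + 1) // 2


# BatchSampler wraps any Sampler to yield lists of indices instead of single indices.
batch_sampler = BatchSampler(EvenIndexSampler(range(10)), batch_size=2, drop_last=False)
assert list(batch_sampler) == [[0, 2], [4, 6], [8]]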
@@ -20,7 +20,7 @@ from enum import Enum
from typing import List, Optional, Union
import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
from filelock import FileLock
......
@@ -21,7 +21,7 @@ import warnings
from typing import Dict, List, Optional
import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
from filelock import FileLock
......
@@ -19,7 +19,7 @@ from enum import Enum
from typing import Dict, List, Optional, Union
import torch
-from torch.utils.data.dataset import Dataset
+from torch.utils.data import Dataset
from filelock import FileLock
......
@@ -49,10 +49,8 @@ import numpy as np
import torch
from packaging import version
from torch import nn
-from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler, SequentialSampler
from . import __version__
from .configuration_utils import PretrainedConfig
@@ -206,16 +204,16 @@ class Trainer:
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
:func:`~transformers.DataCollatorWithPadding` otherwise.
-train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
+train_dataset (:obj:`torch.utils.data.Dataset` or :obj:`torch.utils.data.IterableDataset`, `optional`):
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
-Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
-training in a distributed fashion, your iterable dataset should either use a internal attribute
-:obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
-processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
-:obj:`set_epoch()` method that internally sets the seed of the RNGs used.
-eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+Note that if it's a :obj:`torch.utils.data.IterableDataset` with some randomization and you are training in
+a distributed fashion, your iterable dataset should either use a internal attribute :obj:`generator` that
+is a :obj:`torch.Generator` for the randomization that must be identical on all processes (and the Trainer
+will manually set the seed of this :obj:`generator` at each epoch) or have a :obj:`set_epoch()` method that
+internally sets the seed of the RNGs used.
+eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
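To make the requirement in the docstring above concrete, here is a hedged sketch of an iterable dataset whose randomization is controlled by a set_epoch() method, so every process shuffles identically for a given epoch (hypothetical class, not part of the library):

import torch
from torch.utils.data import IterableDataset


class ShuffledRangeDataset(IterableDataset):
    """Hypothetical iterable dataset yielding 0..n-1 in a seeded random order."""

    def __init__(self, n: int, seed: int = 0):
        self.n = n
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int):
        # Called once per epoch; the same value on every process keeps shuffles in sync.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        yield from torch.randperm(self.n, generator=g).tolist()


ds = ShuffledRangeDataset(8)
ds.set_epoch(1)
print(list(ds))  # the same order on every process for a given epoch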
@@ -537,7 +535,7 @@ class Trainer:
else:
return dataset.remove_columns(ignored_columns)
-def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
+def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized):
return None
@@ -617,7 +615,7 @@ class Trainer:
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
-if isinstance(train_dataset, torch.utils.data.dataset.IterableDataset):
+if isinstance(train_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
train_dataset,
@@ -647,7 +645,7 @@ class Trainer:
pin_memory=self.args.dataloader_pin_memory,
)
-def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
+def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
# Deprecated code
if self.args.use_legacy_prediction_loop:
if is_torch_tpu_available():
@@ -683,7 +681,7 @@ class Trainer:
Subclass and override this method if you want to inject some custom behavior.
Args:
-eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
@@ -694,7 +692,7 @@ class Trainer:
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
-if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset):
+if isinstance(eval_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
eval_dataset = IterableDatasetShard(
eval_dataset,
@@ -730,14 +728,14 @@ class Trainer:
Subclass and override this method if you want to inject some custom behavior.
Args:
-test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+test_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
-if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset):
+if isinstance(test_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
test_dataset = IterableDatasetShard(
test_dataset,
......
@@ -175,9 +175,9 @@ class TrainerCallback:
The optimizer used for the training steps.
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
The scheduler used for setting the learning rate.
-train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
+train_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
The current dataloader used for training.
-eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
+eval_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
The current dataloader used for training.
metrics (:obj:`Dict[str, float]`):
The metrics computed by the last evaluation phase.
......
@@ -29,9 +29,8 @@ import numpy as np
import torch
from packaging import version
from torch import nn
-from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
from .tokenization_utils_base import BatchEncoding
@@ -290,7 +289,7 @@ class SequentialDistributedSampler(Sampler):
return self.num_samples
-def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
+def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
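For reference, a sampler such as the one returned by the helper above is passed to a DataLoader through its sampler argument. A minimal single-process sketch with a toy dataset (illustrative only; RandomSampler stands in for whatever get_tpu_sampler would return):

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
sampler = RandomSampler(dataset)  # stand-in for the sampler chosen above
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([4, 1]) for full batches, torch.Size([2, 1]) for the last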
@@ -690,7 +689,7 @@ class IterableDatasetShard(IterableDataset):
Args:
-dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
+dataset (:obj:`torch.utils.data.IterableDataset`):
The batch sampler to split in several shards.
batch_size (:obj:`int`, `optional`, defaults to 1):
The size of the batches per shard.
......
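The last hunk's docstring describes splitting an iterable dataset into per-process shards. Conceptually, each participating process keeps one slice of the incoming stream; a simplified sketch of that idea (not the actual IterableDatasetShard implementation):

from torch.utils.data import IterableDataset


class SimpleIterableShard(IterableDataset):
    """Toy shard: process `rank` out of `world_size` keeps every world_size-th item."""

    def __init__(self, dataset, rank: int = 0, world_size: int = 1):
        self.dataset = dataset
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        for i, item in enumerate(self.dataset):
            if i % self.world_size == self.rank:
                yield item


# Two "processes" together cover the full stream exactly once.
stream = range(7)
assert list(SimpleIterableShard(stream, rank=0, world_size=2)) == [0, 2, 4, 6]
assert list(SimpleIterableShard(stream, rank=1, world_size=2)) == [1, 3, 5]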