Unverified Commit 1efb567f authored by Philip Meier, committed by GitHub

remove fn_kwargs from Filter and Mapper datapipes (#5113)

* remove fn_kwargs from Filter and Mapper datapipes

* fix leftovers
parent 40be6576
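
For context: fn_kwargs let the datapipe forward extra keyword arguments to the callable, while the replacement binds them onto the callable itself with functools.partial from the standard library. A minimal sketch of the two spellings (the helper name and the decoder value are invented for illustration):

import functools

def _collate_and_decode_sample(data, *, decoder=None):
    # Stand-in for the dataset helpers touched in this commit.
    return decoder(data) if decoder is not None else data

# Before: Mapper(dp, _collate_and_decode_sample, fn_kwargs=dict(decoder=...))
# After: bind the keyword argument up front and pass a plain one-argument callable.
fn = functools.partial(_collate_and_decode_sample, decoder=bytes.upper)
print(fn(b"sample"))  # b'SAMPLE', same as _collate_and_decode_sample(b"sample", decoder=bytes.upper)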
+import functools
 import io
 import pathlib
 import re
@@ -132,7 +133,7 @@ class Caltech101(Dataset):
             buffer_size=INFINITE_BUFFER_SIZE,
             keep_key=True,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
@@ -185,7 +186,7 @@ class Caltech256(Dataset):
         dp = Filter(dp, self._is_not_rogue_file)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
 import csv
+import functools
 import io
 from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
@@ -26,7 +27,6 @@ from torchvision.prototype.datasets.utils._internal import (
     hint_shuffling,
 )

 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
@@ -155,7 +155,7 @@ class CelebA(Dataset):
         splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
-        splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
+        splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
         splits_dp = hint_sharding(splits_dp)
         splits_dp = hint_shuffling(splits_dp)
@@ -181,4 +181,4 @@ class CelebA(Dataset):
             keep_key=True,
         )
         dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
@@ -89,7 +89,7 @@ class _CifarBase(Dataset):
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
+import functools
 import io
 import pathlib
 import re
@@ -183,12 +184,16 @@ class Coco(Dataset):
         if config.annotations is None:
             dp = hint_sharding(images_dp)
             dp = hint_shuffling(dp)
-            return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
+            return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder))

         meta_dp = Filter(
             meta_dp,
-            self._filter_meta_files,
-            fn_kwargs=dict(split=config.split, year=config.year, annotations=config.annotations),
+            functools.partial(
+                self._filter_meta_files,
+                split=config.split,
+                year=config.year,
+                annotations=config.annotations,
+            ),
         )
         meta_dp = JsonParser(meta_dp)
         meta_dp = Mapper(meta_dp, getitem(1))
@@ -226,7 +231,7 @@ class Coco(Dataset):
             buffer_size=INFINITE_BUFFER_SIZE,
         )
         return Mapper(
-            dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder)
+            dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder)
         )

     def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
@@ -235,7 +240,8 @@ class Coco(Dataset):
         dp = resources[1].load(pathlib.Path(root) / self.name)
         dp = Filter(
-            dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances")
+            dp,
+            functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
         )
         dp = JsonParser(dp)
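The Coco hunks above show the multi-keyword form: functools.partial accepts any number of keyword arguments in a single call, so a three-entry fn_kwargs dict translates one-to-one. A toy sketch with invented names:

import functools

def _filter_meta_files(name, *, split, year, annotations):
    # Stand-in for the predicate bound in the Coco changes above.
    return name == f"{annotations}_{split}{year}.json"

keep = functools.partial(_filter_meta_files, split="train", year="2017", annotations="instances")
print(keep("instances_train2017.json"))  # True
print(keep("captions_val2017.json"))  # False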
+import functools
 import io
 import pathlib
 import re
@@ -165,7 +166,7 @@ class ImageNet(Dataset):
             dp = hint_shuffling(dp)
             dp = Mapper(dp, self._collate_test_data)

-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

     # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
     # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
@@ -136,7 +136,7 @@ class _MNISTBase(Dataset):
         dp = Zipper(images_dp, labels_dp)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder))


 class MNIST(_MNISTBase):
+import functools
 import io
 import pathlib
 import re
@@ -152,7 +153,7 @@ class SBD(Dataset):
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
+import functools
 import io
 from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -65,5 +66,5 @@ class SEMEION(Dataset):
         dp = CSVParser(dp, delimiter=" ")
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
         return dp
@@ -127,7 +127,7 @@ class VOC(Dataset):
             buffer_size=INFINITE_BUFFER_SIZE,
         )

-        split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task]))
+        split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task]))
         split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_sharding(split_dp)
@@ -142,4 +142,4 @@ class VOC(Dataset):
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
+import functools
 import io
 import os
 import os.path
@@ -50,12 +51,12 @@ def from_data_folder(
     categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
     masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
     dp = FileLister(str(root), recursive=recursive, masks=masks)
-    dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
+    dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
     dp = hint_sharding(dp)
     dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
     dp = FileLoader(dp)

     return (
-        Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
+        Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
         categories,
     )
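Taken together, the pattern is identical in every file: the predicate or collation helper is wrapped once, and the datapipe only ever sees a one-argument callable. A self-contained usage sketch, assuming a torchdata installation where Mapper, Filter, and IterableWrapper are importable from torchdata.datapipes.iter (the helper names are invented):

import functools

from torchdata.datapipes.iter import Filter, IterableWrapper, Mapper

def _is_multiple(value, *, divisor):
    # Predicate with its extra argument bound via functools.partial below.
    return value % divisor == 0

def _scale(value, *, factor):
    # Mapping function, likewise bound below.
    return value * factor

dp = IterableWrapper(range(10))
dp = Filter(dp, functools.partial(_is_multiple, divisor=3))
dp = Mapper(dp, functools.partial(_scale, factor=10))
print(list(dp))  # [0, 30, 60, 90]

One practical property worth noting (a general standard-library fact, not something this commit states): functools.partial objects pickle whenever the wrapped callable and the bound arguments do, so the wrapped helpers survive the trip to DataLoader worker processes just as the fn_kwargs spelling did.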