Unverified Commit 1efb567f authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove fn_kwargs from Filter and Mapper datapipes (#5113)

* remove fn_kwargs from Filter and Mapper datapipes

* fix leftovers
parent 40be6576
import functools
import io import io
import pathlib import pathlib
import re import re
...@@ -132,7 +133,7 @@ class Caltech101(Dataset): ...@@ -132,7 +133,7 @@ class Caltech101(Dataset):
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True, keep_key=True,
) )
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> List[str]: def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
...@@ -185,7 +186,7 @@ class Caltech256(Dataset): ...@@ -185,7 +186,7 @@ class Caltech256(Dataset):
dp = Filter(dp, self._is_not_rogue_file) dp = Filter(dp, self._is_not_rogue_file)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = hint_shuffling(dp) dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> List[str]: def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
......
import csv import csv
import functools
import io import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
...@@ -26,7 +27,6 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -26,7 +27,6 @@ from torchvision.prototype.datasets.utils._internal import (
hint_shuffling, hint_shuffling,
) )
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
...@@ -155,7 +155,7 @@ class CelebA(Dataset): ...@@ -155,7 +155,7 @@ class CelebA(Dataset):
splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split)) splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
splits_dp = hint_sharding(splits_dp) splits_dp = hint_sharding(splits_dp)
splits_dp = hint_shuffling(splits_dp) splits_dp = hint_shuffling(splits_dp)
...@@ -181,4 +181,4 @@ class CelebA(Dataset): ...@@ -181,4 +181,4 @@ class CelebA(Dataset):
keep_key=True, keep_key=True,
) )
dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE) dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
...@@ -89,7 +89,7 @@ class _CifarBase(Dataset): ...@@ -89,7 +89,7 @@ class _CifarBase(Dataset):
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = hint_shuffling(dp) dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> List[str]: def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
......
import functools
import io import io
import pathlib import pathlib
import re import re
...@@ -183,12 +184,16 @@ class Coco(Dataset): ...@@ -183,12 +184,16 @@ class Coco(Dataset):
if config.annotations is None: if config.annotations is None:
dp = hint_sharding(images_dp) dp = hint_sharding(images_dp)
dp = hint_shuffling(dp) dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder))
meta_dp = Filter( meta_dp = Filter(
meta_dp, meta_dp,
self._filter_meta_files, functools.partial(
fn_kwargs=dict(split=config.split, year=config.year, annotations=config.annotations), self._filter_meta_files,
split=config.split,
year=config.year,
annotations=config.annotations,
),
) )
meta_dp = JsonParser(meta_dp) meta_dp = JsonParser(meta_dp)
meta_dp = Mapper(meta_dp, getitem(1)) meta_dp = Mapper(meta_dp, getitem(1))
...@@ -226,7 +231,7 @@ class Coco(Dataset): ...@@ -226,7 +231,7 @@ class Coco(Dataset):
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
) )
return Mapper( return Mapper(
dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder) dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder)
) )
def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
...@@ -235,7 +240,8 @@ class Coco(Dataset): ...@@ -235,7 +240,8 @@ class Coco(Dataset):
dp = resources[1].load(pathlib.Path(root) / self.name) dp = resources[1].load(pathlib.Path(root) / self.name)
dp = Filter( dp = Filter(
dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances") dp,
functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
) )
dp = JsonParser(dp) dp = JsonParser(dp)
......
import functools
import io import io
import pathlib import pathlib
import re import re
...@@ -165,7 +166,7 @@ class ImageNet(Dataset): ...@@ -165,7 +166,7 @@ class ImageNet(Dataset):
dp = hint_shuffling(dp) dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_test_data) dp = Mapper(dp, self._collate_test_data)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849 # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
......
...@@ -136,7 +136,7 @@ class _MNISTBase(Dataset): ...@@ -136,7 +136,7 @@ class _MNISTBase(Dataset):
dp = Zipper(images_dp, labels_dp) dp = Zipper(images_dp, labels_dp)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = hint_shuffling(dp) dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder))
class MNIST(_MNISTBase): class MNIST(_MNISTBase):
......
import functools
import io import io
import pathlib import pathlib
import re import re
...@@ -152,7 +153,7 @@ class SBD(Dataset): ...@@ -152,7 +153,7 @@ class SBD(Dataset):
ref_key_fn=path_accessor("stem"), ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
) )
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
......
import functools
import io import io
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
...@@ -65,5 +66,5 @@ class SEMEION(Dataset): ...@@ -65,5 +66,5 @@ class SEMEION(Dataset):
dp = CSVParser(dp, delimiter=" ") dp = CSVParser(dp, delimiter=" ")
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = hint_shuffling(dp) dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return dp return dp
...@@ -127,7 +127,7 @@ class VOC(Dataset): ...@@ -127,7 +127,7 @@ class VOC(Dataset):
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
) )
split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task])) split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task]))
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True) split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp) split_dp = hint_sharding(split_dp)
...@@ -142,4 +142,4 @@ class VOC(Dataset): ...@@ -142,4 +142,4 @@ class VOC(Dataset):
ref_key_fn=path_accessor("stem"), ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
) )
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
import functools
import io import io
import os import os
import os.path import os.path
...@@ -50,12 +51,12 @@ def from_data_folder( ...@@ -50,12 +51,12 @@ def from_data_folder(
categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir()) categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
dp = FileLister(str(root), recursive=recursive, masks=masks) dp = FileLister(str(root), recursive=recursive, masks=masks)
dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = FileLoader(dp) dp = FileLoader(dp)
return ( return (
Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)), Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
categories, categories,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment