"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8eb73c872afbe59abab4580aaa591a9851a42e6d"
Unverified Commit bb79470a authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Moving `sequence_to_str` to `torchvision._utils` (#5604)

* Moving `sequence_to_str` to `torchvision._utils`

* Fix linter

* Rename test_prototype_utils test to test_internal_utils
parent 5ddd564e
...@@ -19,8 +19,8 @@ import torch ...@@ -19,8 +19,8 @@ import torch
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor from torch.testing import make_tensor as _make_tensor
from torchvision._utils import sequence_to_str
from torchvision.prototype.datasets._api import find from torchvision.prototype.datasets._api import find
from torchvision.prototype.utils._internal import sequence_to_str
make_tensor = functools.partial(_make_tensor, device="cpu") make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ()) make_scalar = functools.partial(make_tensor, ())
......
import pytest import pytest
from torchvision.prototype.utils._internal import sequence_to_str from torchvision._utils import sequence_to_str
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -10,8 +10,8 @@ from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair ...@@ -10,8 +10,8 @@ from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets from torchvision.prototype import transforms, datasets
from torchvision.prototype.utils._internal import sequence_to_str
assert_samples_equal = functools.partial( assert_samples_equal = functools.partial(
......
import enum import enum
from typing import TypeVar, Type from typing import Sequence, TypeVar, Type
T = TypeVar("T", bound=enum.Enum) T = TypeVar("T", bound=enum.Enum)
...@@ -18,3 +18,15 @@ class StrEnumMeta(enum.EnumMeta): ...@@ -18,3 +18,15 @@ class StrEnumMeta(enum.EnumMeta):
class StrEnum(enum.Enum, metaclass=StrEnumMeta): class StrEnum(enum.Enum, metaclass=StrEnumMeta):
pass pass
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1:
return f"'{seq[0]}'"
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
return head + tail
...@@ -7,7 +7,8 @@ import pathlib ...@@ -7,7 +7,8 @@ import pathlib
from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection
from torch.utils.data import IterDataPipe from torch.utils.data import IterDataPipe
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion
from .._home import use_sharded_dataset from .._home import use_sharded_dataset
from ._internal import BUILTIN_DIR, _make_sharded_datapipe from ._internal import BUILTIN_DIR, _make_sharded_datapipe
......
...@@ -28,9 +28,10 @@ from typing import ( ...@@ -28,9 +28,10 @@ from typing import (
import numpy as np import numpy as np
import torch import torch
from torchvision._utils import sequence_to_str
__all__ = [ __all__ = [
"sequence_to_str",
"add_suggestion", "add_suggestion",
"FrozenMapping", "FrozenMapping",
"make_repr", "make_repr",
...@@ -43,18 +44,6 @@ __all__ = [ ...@@ -43,18 +44,6 @@ __all__ = [
] ]
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1:
return f"'{seq[0]}'"
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
return head + tail
def add_suggestion( def add_suggestion(
msg: str, msg: str,
*, *,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment