Unverified Commit bb79470a authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Moving `sequence_to_str` to `torchvision._utils` (#5604)

* Moving `sequence_to_str` to `torchvision._utils`

* Fix linter

* Rename test_prototype_utils test to test_internal_utils
parent 5ddd564e
......@@ -19,8 +19,8 @@ import torch
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
from torchvision._utils import sequence_to_str
from torchvision.prototype.datasets._api import find
from torchvision.prototype.utils._internal import sequence_to_str
make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ())
......
import pytest
from torchvision.prototype.utils._internal import sequence_to_str
from torchvision._utils import sequence_to_str
@pytest.mark.parametrize(
......
......@@ -10,8 +10,8 @@ from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets
from torchvision.prototype.utils._internal import sequence_to_str
assert_samples_equal = functools.partial(
......
import enum
from typing import TypeVar, Type
from typing import Sequence, TypeVar, Type
T = TypeVar("T", bound=enum.Enum)
......@@ -18,3 +18,15 @@ class StrEnumMeta(enum.EnumMeta):
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
pass
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1:
return f"'{seq[0]}'"
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
return head + tail
......@@ -7,7 +7,8 @@ import pathlib
from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection
from torch.utils.data import IterDataPipe
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str
from torchvision._utils import sequence_to_str
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion
from .._home import use_sharded_dataset
from ._internal import BUILTIN_DIR, _make_sharded_datapipe
......
......@@ -28,9 +28,10 @@ from typing import (
import numpy as np
import torch
from torchvision._utils import sequence_to_str
__all__ = [
"sequence_to_str",
"add_suggestion",
"FrozenMapping",
"make_repr",
......@@ -43,18 +44,6 @@ __all__ = [
]
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1:
return f"'{seq[0]}'"
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
return head + tail
def add_suggestion(
msg: str,
*,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment