You need to sign in or sign up before continuing.
Unverified Commit abc6c778 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix and add test for sequence_to_str (#5213)

* fix and add test for sequence_to_str

* remove manual ids
parent afdf1261
import pytest
from torchvision.prototype.utils._internal import sequence_to_str
@pytest.mark.parametrize(
("seq", "separate_last", "expected"),
[
([], "", ""),
(["foo"], "", "'foo'"),
(["foo", "bar"], "", "'foo', 'bar'"),
(["foo", "bar"], "and ", "'foo' and 'bar'"),
(["foo", "bar", "baz"], "", "'foo', 'bar', 'baz'"),
(["foo", "bar", "baz"], "and ", "'foo', 'bar', and 'baz'"),
],
)
def test_sequence_to_str(seq, separate_last, expected):
assert sequence_to_str(seq, separate_last=separate_last) == expected
...@@ -30,10 +30,15 @@ class StrEnum(enum.Enum, metaclass=StrEnumMeta): ...@@ -30,10 +30,15 @@ class StrEnum(enum.Enum, metaclass=StrEnumMeta):
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1: if len(seq) == 1:
return f"'{seq[0]}'" return f"'{seq[0]}'"
return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'""" head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
return head + tail
def add_suggestion( def add_suggestion(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment