Unverified Commit fe78a8ae authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add prototype features (#4721)

* add prototype features

* add some JIT tests

* refactor input data handling

* refactor tests

* cleanup tests

* add BoundingBox feature

* mypy

* xfail torchscript tests for now

* cleanup

* fix imports
parent 49ec677c
...@@ -15,7 +15,7 @@ from torch.testing import make_tensor as _make_tensor ...@@ -15,7 +15,7 @@ from torch.testing import make_tensor as _make_tensor
from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import find from torchvision.prototype.datasets._api import find
from torchvision.prototype.datasets.utils._internal import add_suggestion from torchvision.prototype.utils._internal import add_suggestion
make_tensor = functools.partial(_make_tensor, device="cpu") make_tensor = functools.partial(_make_tensor, device="cpu")
......
...@@ -5,7 +5,7 @@ import builtin_dataset_mocks ...@@ -5,7 +5,7 @@ import builtin_dataset_mocks
import pytest import pytest
from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import sequence_to_str from torchvision.prototype.utils._internal import sequence_to_str
_loaders = [] _loaders = []
......
import functools
import itertools
import pytest
import torch
from torch.testing import make_tensor as _make_tensor, assert_close
from torchvision.prototype import features
from torchvision.prototype.utils._internal import sequence_to_str
make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)
def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
height, width = image_size
if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, ())
y1 = torch.randint(0, height // 2, ())
x2 = torch.randint(int(x1) + 1, width - int(x1), ()) + x1
y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1
parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, ())
y = torch.randint(0, height // 2, ())
w = torch.randint(1, width - int(x), ())
h = torch.randint(1, height - int(y), ())
parts = (x, y, w, h)
elif format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ())
cy = torch.randint(1, height - 1, ())
w = torch.randint(1, min(int(cx), width - int(cx)), ())
h = torch.randint(1, min(int(cy), height - int(cy)), ())
parts = (cx, cy, w, h)
else: # format == features.BoundingBoxFormat._SENTINEL:
parts = make_tensor((4,)).unbind()
return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size)
MAKE_DATA_MAP = {
features.BoundingBox: make_bounding_box,
}
def make_feature(feature_type, **meta_data):
maker = MAKE_DATA_MAP.get(feature_type, lambda **meta_data: feature_type(make_tensor(()), **meta_data))
return maker(**meta_data)
class TestCommon:
FEATURE_TYPES, NON_DEFAULT_META_DATA = zip(
*(
(features.Image, dict(color_space=features.ColorSpace._SENTINEL)),
(features.Label, dict(category="category")),
(features.BoundingBox, dict(format=features.BoundingBoxFormat._SENTINEL, image_size=(-1, -1))),
)
)
feature_types = pytest.mark.parametrize(
"feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__
)
features = pytest.mark.parametrize(
"feature",
[
pytest.param(make_feature(feature_type, **meta_data), id=feature_type.__name__)
for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA)
],
)
def test_consistency(self):
builtin_feature_types = {
name
for name, feature_type in features.__dict__.items()
if not name.startswith("_")
and isinstance(feature_type, type)
and issubclass(feature_type, features.Feature)
and feature_type is not features.Feature
}
untested_feature_types = builtin_feature_types - {feature_type.__name__ for feature_type in self.FEATURE_TYPES}
if untested_feature_types:
raise AssertionError(
f"The feature(s) {sequence_to_str(sorted(untested_feature_types), separate_last='and ')} "
f"is/are exposed at `torchvision.prototype.features`, but is/are not tested by `TestCommon`. "
f"Please add it/them to `TestCommon.FEATURE_TYPES`."
)
@features
def test_meta_data_attribute_access(self, feature):
for name, value in feature._meta_data.items():
assert getattr(feature, name) == feature._meta_data[name]
@feature_types
def test_torch_function(self, feature_type):
input = make_feature(feature_type)
# This can be any Tensor operation besides clone
output = input + 1
assert type(output) is torch.Tensor
assert_close(output, input + 1)
@feature_types
def test_clone(self, feature_type):
input = make_feature(feature_type)
output = input.clone()
assert type(output) is feature_type
assert_close(output, input)
assert output._meta_data == input._meta_data
@features
def test_serialization(self, tmpdir, feature):
file = tmpdir / "test_serialization.pt"
torch.save(feature, str(file))
loaded_feature = torch.load(str(file))
assert isinstance(loaded_feature, type(feature))
assert_close(loaded_feature, feature)
assert loaded_feature._meta_data == feature._meta_data
@features
def test_repr(self, feature):
assert type(feature).__name__ in repr(feature)
class TestBoundingBox:
@pytest.mark.parametrize(("format", "intermediate_format"), itertools.permutations(("xyxy", "xywh"), 2))
def test_cycle_consistency(self, format, intermediate_format):
input = make_bounding_box(format=format)
output = input.convert(intermediate_format).convert(format)
assert_close(input, output)
# For now, tensor subclasses with additional meta data do not work with torchscript.
# See https://github.com/pytorch/vision/pull/4721#discussion_r741676037.
@pytest.mark.xfail
class TestJit:
def test_bounding_box(self):
def resize(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
old_height, old_width = input.image_size
new_height, new_width = size
height_scale = new_height / old_height
width_scale = new_width / old_width
old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts()
new_x1 = old_x1 * width_scale
new_y1 = old_y1 * height_scale
new_x2 = old_x2 * width_scale
new_y2 = old_y2 * height_scale
return features.BoundingBox.from_parts(
new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=tuple(size.tolist())
)
def horizontal_flip(input: features.BoundingBox) -> features.BoundingBox:
x, y, w, h = input.convert("xywh").to_parts()
x = input.image_size[1] - (x + w)
return features.BoundingBox.from_parts(x, y, w, h, like=input, format="xywh")
def compose(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
return horizontal_flip(resize(input, size)).convert("xyxy")
image_size = (8, 6)
input = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=image_size)
size = torch.tensor((4, 12))
expected = features.BoundingBox([6, 1, 10, 3], format="xyxy", image_size=image_size)
actual_eager = compose(input, size)
assert_close(actual_eager, expected)
sample_inputs = (features.BoundingBox(torch.zeros((4,)), image_size=(10, 10)), torch.tensor((20, 5)))
actual_jit = torch.jit.trace(compose, sample_inputs, check_trace=False)(input, size)
assert_close(actual_jit, expected)
from . import datasets from . import datasets
from . import features
from . import models from . import models
from . import transforms from . import transforms
from . import utils
...@@ -7,7 +7,7 @@ from torch.utils.data import IterDataPipe ...@@ -7,7 +7,7 @@ from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import raw, pil from torchvision.prototype.datasets.decoder import raw, pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
from torchvision.prototype.datasets.utils._internal import add_suggestion from torchvision.prototype.utils._internal import add_suggestion
from . import _builtin from . import _builtin
......
...@@ -31,6 +31,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -31,6 +31,7 @@ from torchvision.prototype.datasets.utils._internal import (
Decompressor, Decompressor,
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
) )
from torchvision.prototype.features import Image, Label
__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"] __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
...@@ -126,15 +127,14 @@ class _MNISTBase(Dataset): ...@@ -126,15 +127,14 @@ class _MNISTBase(Dataset):
image, label = data image, label = data
if decoder is raw: if decoder is raw:
image = image.unsqueeze(0) image = Image(image)
else: else:
image_buffer = image_buffer_from_array(image.numpy()) image_buffer = image_buffer_from_array(image.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
category = self.info.categories[int(label)] label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)])
label = label.to(torch.int64)
return dict(image=image, category=category, label=label) return dict(image=image, label=label)
def _make_datapipe( def _make_datapipe(
self, self,
......
import io import io
from typing import cast
import PIL.Image import PIL.Image
import torch import torch
from torchvision.prototype import features
from torchvision.transforms.functional import pil_to_tensor from torchvision.transforms.functional import pil_to_tensor
__all__ = ["raw", "pil"] __all__ = ["raw", "pil"]
...@@ -12,5 +12,5 @@ def raw(buffer: io.IOBase) -> torch.Tensor: ...@@ -12,5 +12,5 @@ def raw(buffer: io.IOBase) -> torch.Tensor:
raise RuntimeError("This is just a sentinel and should never be called.") raise RuntimeError("This is just a sentinel and should never be called.")
def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor: def pil(buffer: io.IOBase) -> features.Image:
return cast(torch.Tensor, pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper()))) return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
...@@ -9,10 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple ...@@ -9,10 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
import torch import torch
from torch.utils.data import IterDataPipe from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str
add_suggestion,
sequence_to_str,
)
from .._home import use_sharded_dataset from .._home import use_sharded_dataset
from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe
......
import collections.abc
import csv import csv
import difflib
import enum import enum
import gzip import gzip
import io import io
...@@ -11,7 +9,6 @@ import pathlib ...@@ -11,7 +9,6 @@ import pathlib
import pickle import pickle
import textwrap import textwrap
from typing import ( from typing import (
Collection,
Sequence, Sequence,
Callable, Callable,
Union, Union,
...@@ -41,8 +38,6 @@ from torchdata.datapipes.utils import StreamWrapper ...@@ -41,8 +38,6 @@ from torchdata.datapipes.utils import StreamWrapper
__all__ = [ __all__ = [
"INFINITE_BUFFER_SIZE", "INFINITE_BUFFER_SIZE",
"BUILTIN_DIR", "BUILTIN_DIR",
"sequence_to_str",
"add_suggestion",
"make_repr", "make_repr",
"FrozenMapping", "FrozenMapping",
"FrozenBunch", "FrozenBunch",
...@@ -67,33 +62,6 @@ INFINITE_BUFFER_SIZE = 1_000_000_000 ...@@ -67,33 +62,6 @@ INFINITE_BUFFER_SIZE = 1_000_000_000
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin" BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if len(seq) == 1:
return f"'{seq[0]}'"
return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'."""
def add_suggestion(
msg: str,
*,
word: str,
possibilities: Collection[str],
close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
alternative_hint: Callable[
[Sequence[str]], str
] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
) -> str:
if not isinstance(possibilities, collections.abc.Sequence):
possibilities = sorted(possibilities)
suggestions = difflib.get_close_matches(word, possibilities, 1)
hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
if not hint:
return msg
return f"{msg.strip()} {hint}"
def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str: def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
def to_str(sep: str) -> str: def to_str(sep: str) -> str:
return sep.join([f"{key}={value}" for key, value in items]) return sep.join([f"{key}={value}" for key, value in items])
......
from ._bounding_box import BoundingBoxFormat, BoundingBox
from ._feature import Feature
from ._image import Image, ColorSpace
from ._label import Label
import enum
import functools
from typing import Callable, Union, Tuple, Dict, Any, Optional, cast
import torch
from torchvision.prototype.utils._internal import StrEnum
from ._feature import Feature, DEFAULT
class BoundingBoxFormat(StrEnum):
# this is just for test purposes
_SENTINEL = -1
XYXY = enum.auto()
XYWH = enum.auto()
CXCYWH = enum.auto()
def to_parts(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return input.unbind(dim=-1) # type: ignore[return-value]
def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
return torch.stack((a, b, c, d), dim=-1)
def format_converter_wrapper(
part_converter: Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]
):
def wrapper(input: torch.Tensor) -> torch.Tensor:
return from_parts(*part_converter(*to_parts(input)))
return wrapper
@format_converter_wrapper
def xywh_to_xyxy(
x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x1 = x
y1 = y
x2 = x + w
y2 = y + h
return x1, y1, x2, y2
@format_converter_wrapper
def xyxy_to_xywh(
x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x = x1
y = y1
w = x2 - x1
h = y2 - y1
return x, y, w, h
@format_converter_wrapper
def cxcywh_to_xyxy(
cx: torch.Tensor, cy: torch.Tensor, w: torch.Tensor, h: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x1 = cx - 0.5 * w
y1 = cy - 0.5 * h
x2 = cx + 0.5 * w
y2 = cy + 0.5 * h
return x1, y1, x2, y2
@format_converter_wrapper
def xyxy_to_cxcywh(
x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
return cx, cy, w, h
class BoundingBox(Feature):
formats = BoundingBoxFormat
format: BoundingBoxFormat
image_size: Tuple[int, int]
@classmethod
def _parse_meta_data(
cls,
format: Union[str, BoundingBoxFormat] = DEFAULT, # type: ignore[assignment]
image_size: Optional[Tuple[int, int]] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
if isinstance(format, str):
format = BoundingBoxFormat[format]
format_fallback = BoundingBoxFormat.XYXY
return dict(
format=(format, format_fallback),
image_size=(image_size, functools.partial(cls.guess_image_size, format=format_fallback)),
)
_TO_XYXY_MAP = {
BoundingBoxFormat.XYWH: xywh_to_xyxy,
BoundingBoxFormat.CXCYWH: cxcywh_to_xyxy,
}
_FROM_XYXY_MAP = {
BoundingBoxFormat.XYWH: xyxy_to_xywh,
BoundingBoxFormat.CXCYWH: xyxy_to_cxcywh,
}
@classmethod
def guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> Tuple[int, int]:
if format not in (BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH):
if format != BoundingBoxFormat.XYXY:
data = cls._TO_XYXY_MAP[format](data)
data = cls._FROM_XYXY_MAP[BoundingBoxFormat.XYWH](data)
*_, w, h = to_parts(data)
return int(h.ceil()), int(w.ceil())
@classmethod
def from_parts(
cls,
a,
b,
c,
d,
*,
like: Optional["BoundingBox"] = None,
format: Union[str, BoundingBoxFormat] = DEFAULT, # type: ignore[assignment]
image_size: Optional[Tuple[int, int]] = DEFAULT, # type: ignore[assignment]
) -> "BoundingBox":
return cls(from_parts(a, b, c, d), like=like, image_size=image_size, format=format)
def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return to_parts(self)
def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
if isinstance(format, str):
format = BoundingBoxFormat[format]
if format == self.format:
return cast(BoundingBox, self.clone())
data = self
if self.format != BoundingBoxFormat.XYXY:
data = self._TO_XYXY_MAP[self.format](data)
if format != BoundingBoxFormat.XYXY:
data = self._FROM_XYXY_MAP[format](data)
return BoundingBox(data, like=self, format=format)
from typing import Tuple, cast, TypeVar, Set, Dict, Any, Callable, Optional, Mapping, Type, Sequence
import torch
from torch._C import _TensorBase, DisableTorchFunction
from torchvision.prototype.utils._internal import add_suggestion
F = TypeVar("F", bound="Feature")
DEFAULT = object()
class Feature(torch.Tensor):
_META_ATTRS: Set[str]
_meta_data: Dict[str, Any]
def __init_subclass__(cls):
if not hasattr(cls, "_META_ATTRS"):
cls._META_ATTRS = {
attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")
}
for attr in cls._META_ATTRS:
if not hasattr(cls, attr):
setattr(cls, attr, property(lambda self, attr=attr: self._meta_data[attr]))
def __new__(cls, data, *, dtype=None, device=None, like=None, **kwargs):
unknown_meta_attrs = kwargs.keys() - cls._META_ATTRS
if unknown_meta_attrs:
unknown_meta_attr = sorted(unknown_meta_attrs)[0]
raise TypeError(
add_suggestion(
f"{cls.__name__}() got unexpected keyword '{unknown_meta_attr}'.",
word=unknown_meta_attr,
possibilities=cls._META_ATTRS,
)
)
if like is not None:
dtype = dtype or like.dtype
device = device or like.device
data = cls._to_tensor(data, dtype=dtype, device=device)
requires_grad = False
self = torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)
meta_data = dict()
for attr, (explicit, fallback) in cls._parse_meta_data(**kwargs).items():
if explicit is not DEFAULT:
value = explicit
elif like is not None:
value = getattr(like, attr)
else:
value = fallback(data) if callable(fallback) else fallback
meta_data[attr] = value
self._meta_data = meta_data
return self
@classmethod
def _to_tensor(cls, data, *, dtype, device):
return torch.as_tensor(data, dtype=dtype, device=device)
@classmethod
def _parse_meta_data(cls) -> Dict[str, Tuple[Any, Any]]:
return dict()
@classmethod
def __torch_function__(
cls,
func: Callable[..., torch.Tensor],
types: Tuple[Type[torch.Tensor], ...],
args: Sequence[Any] = (),
kwargs: Optional[Mapping[str, Any]] = None,
) -> torch.Tensor:
with DisableTorchFunction():
output = func(*args, **(kwargs or dict()))
if func is not torch.Tensor.clone:
return output
return cls(output, like=args[0])
def __repr__(self):
return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__)
from typing import Dict, Any, Union, Tuple
import torch
from torchvision.prototype.utils._internal import StrEnum
from ._feature import Feature, DEFAULT
class ColorSpace(StrEnum):
# this is just for test purposes
_SENTINEL = -1
OTHER = 0
GRAYSCALE = 1
RGB = 3
class Image(Feature):
color_spaces = ColorSpace
color_space: ColorSpace
@classmethod
def _parse_meta_data(
cls,
color_space: Union[str, ColorSpace] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
if isinstance(color_space, str):
color_space = ColorSpace[color_space]
return dict(color_space=(color_space, cls.guess_color_space))
@staticmethod
def guess_color_space(data: torch.Tensor) -> ColorSpace:
if data.ndim < 2:
return ColorSpace.OTHER
elif data.ndim == 2:
return ColorSpace.GRAYSCALE
num_channels = data.shape[-3]
if num_channels == 1:
return ColorSpace.GRAYSCALE
elif num_channels == 3:
return ColorSpace.RGB
else:
return ColorSpace.OTHER
from typing import Dict, Any, Optional, Tuple
from ._feature import Feature, DEFAULT
class Label(Feature):
category: Optional[str]
@classmethod
def _parse_meta_data(
cls,
category: Optional[str] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
return dict(category=(category, None))
import collections.abc
import difflib
import enum
from typing import Sequence, Collection, Callable
__all__ = ["StrEnum", "sequence_to_str", "add_suggestion"]
class StrEnumMeta(enum.EnumMeta):
def __getitem__(self, item):
return super().__getitem__(item.upper() if isinstance(item, str) else item)
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
pass
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if len(seq) == 1:
return f"'{seq[0]}'"
return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'."""
def add_suggestion(
msg: str,
*,
word: str,
possibilities: Collection[str],
close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
alternative_hint: Callable[
[Sequence[str]], str
] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
) -> str:
if not isinstance(possibilities, collections.abc.Sequence):
possibilities = sorted(possibilities)
suggestions = difflib.get_close_matches(word, possibilities, 1)
hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
if not hint:
return msg
return f"{msg.strip()} {hint}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment