Unverified Commit 7f4c55b1 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

prevent feature wrapping if the feature is not the primary operand (#6095)

* prevent feature wrapping if the feature is not the primary operand

* explicitly add feature tests to CI
parent 9cece405
...@@ -43,6 +43,15 @@ jobs: ...@@ -43,6 +43,15 @@ jobs:
id: setup id: setup
run: exit 0 run: exit 0
- name: Run prototype features tests
shell: bash
run: |
pytest \
--durations=20 \
--cov=torchvision/prototype/features \
--cov-report=term-missing \
test/test_prototype_features*.py
- name: Run prototype datasets tests - name: Run prototype datasets tests
if: success() || ( failure() && steps.setup.conclusion == 'success' ) if: success() || ( failure() && steps.setup.conclusion == 'success' )
shell: bash shell: bash
......
import torch
from torchvision.prototype import features
def test_isinstance():
assert isinstance(
features.Label([0, 1, 0], categories=["foo", "bar"]),
torch.Tensor,
)
def test_wrapping_no_copy():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
assert label.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
label_to = label.to(torch.int32)
assert type(label_to) is features.Label
assert label_to.dtype is torch.int32
assert label_to.categories is label.categories
def test_to_feature_reference():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
tensor_to = tensor.to(label)
assert type(tensor_to) is torch.Tensor
assert tensor_to.dtype is torch.int32
def test_clone_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
label_clone = label.clone()
assert type(label_clone) is features.Label
assert label_clone.data_ptr() != label.data_ptr()
assert label_clone.categories is label.categories
def test_other_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here
output = label * 2
assert type(output) is torch.Tensor
def test_new_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here
output = label * 2
label_new = features.Label.new_like(label, output)
assert type(label_new) is features.Label
assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories
...@@ -89,6 +89,13 @@ class _Feature(torch.Tensor): ...@@ -89,6 +89,13 @@ class _Feature(torch.Tensor):
with DisableTorchFunction(): with DisableTorchFunction():
output = func(*args, **kwargs) output = func(*args, **kwargs)
# The __torch_function__ protocol will invoke this method on all types involved in the computation by walking
# the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke
# `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a
# case.
if not isinstance(args[0], cls):
return output
if func is torch.Tensor.clone: if func is torch.Tensor.clone:
return cls.new_like(args[0], output) return cls.new_like(args[0], output)
elif func is torch.Tensor.to: elif func is torch.Tensor.to:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment