"vscode:/vscode.git/clone" did not exist on "865233e2565fa4cbb89e806bf371866f4ef9d56f"
Unverified Commit 7f4c55b1 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

prevent feature wrapping if the feature is not the primary operand (#6095)

* prevent feature wrapping if the feature is not the primary operand

* explicitly add feature tests to CI
parent 9cece405
......@@ -43,6 +43,15 @@ jobs:
id: setup
run: exit 0
- name: Run prototype features tests
shell: bash
run: |
pytest \
--durations=20 \
--cov=torchvision/prototype/features \
--cov-report=term-missing \
test/test_prototype_features*.py
- name: Run prototype datasets tests
if: success() || ( failure() && steps.setup.conclusion == 'success' )
shell: bash
......
import torch
from torchvision.prototype import features
def test_isinstance():
assert isinstance(
features.Label([0, 1, 0], categories=["foo", "bar"]),
torch.Tensor,
)
def test_wrapping_no_copy():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
assert label.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
label_to = label.to(torch.int32)
assert type(label_to) is features.Label
assert label_to.dtype is torch.int32
assert label_to.categories is label.categories
def test_to_feature_reference():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
tensor_to = tensor.to(label)
assert type(tensor_to) is torch.Tensor
assert tensor_to.dtype is torch.int32
def test_clone_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
label_clone = label.clone()
assert type(label_clone) is features.Label
assert label_clone.data_ptr() != label.data_ptr()
assert label_clone.categories is label.categories
def test_other_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here
output = label * 2
assert type(output) is torch.Tensor
def test_new_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here
output = label * 2
label_new = features.Label.new_like(label, output)
assert type(label_new) is features.Label
assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories
......@@ -89,6 +89,13 @@ class _Feature(torch.Tensor):
with DisableTorchFunction():
output = func(*args, **kwargs)
# The __torch_function__ protocol will invoke this method on all types involved in the computation by walking
# the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke
# `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a
# case.
if not isinstance(args[0], cls):
return output
if func is torch.Tensor.clone:
return cls.new_like(args[0], output)
elif func is torch.Tensor.to:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment