"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "e171bee832efa2ec1c70666922939b3f006f88d8"
Unverified Commit a73285aa authored by Nikita Shulga's avatar Nikita Shulga Committed by GitHub
Browse files

Add device check to `io.decode_image` (#7406)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 4a7def80
...@@ -368,6 +368,13 @@ def test_decode_jpeg_cuda(mode, img_path, scripted): ...@@ -368,6 +368,13 @@ def test_decode_jpeg_cuda(mode, img_path, scripted):
# Some difference expected between jpeg implementations # Some difference expected between jpeg implementations
assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
@needs_cuda
def test_decode_image_cuda_raises():
data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
exception_raised = True
with pytest.raises(RuntimeError):
decode_image(data)
@needs_cuda @needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda"))) @pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
......
...@@ -7,6 +7,8 @@ namespace vision { ...@@ -7,6 +7,8 @@ namespace vision {
namespace image { namespace image {
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) { torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that tensor is a CPU tensor
TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor");
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment