Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6afb3496
Unverified
Commit
6afb3496
authored
Sep 28, 2020
by
Francisco Massa
Committed by
GitHub
Sep 28, 2020
Browse files
Add decode_image op (#2718)
* Add decode_image op * Fix lint * More lint * Add C10_EXPORT
parent
898802fe
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
112 additions
and
33 deletions
+112
-33
test/test_image.py
test/test_image.py
+17
-4
torchvision/csrc/cpu/image/image.cpp
torchvision/csrc/cpu/image/image.cpp
+2
-1
torchvision/csrc/cpu/image/image.h
torchvision/csrc/cpu/image/image.h
+1
-1
torchvision/csrc/cpu/image/read_image_cpu.cpp
torchvision/csrc/cpu/image/read_image_cpu.cpp
+27
-0
torchvision/csrc/cpu/image/read_image_cpu.h
torchvision/csrc/cpu/image/read_image_cpu.h
+6
-0
torchvision/csrc/cpu/image/readjpeg_cpu.cpp
torchvision/csrc/cpu/image/readjpeg_cpu.cpp
+7
-0
torchvision/csrc/cpu/image/readpng_cpu.cpp
torchvision/csrc/cpu/image/readpng_cpu.cpp
+7
-0
torchvision/io/image.py
torchvision/io/image.py
+45
-27
No files found.
test/test_image.py
View file @
6afb3496
...
...
@@ -8,7 +8,7 @@ import torch
import
torchvision
from
PIL
import
Image
from
torchvision.io.image
import
(
read_png
,
decode_png
,
read_jpeg
,
decode_jpeg
,
encode_jpeg
,
write_jpeg
)
read_png
,
decode_png
,
read_jpeg
,
decode_jpeg
,
encode_jpeg
,
write_jpeg
,
decode_image
,
_read_file
)
import
numpy
as
np
IMAGE_ROOT
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"assets"
)
...
...
@@ -44,10 +44,10 @@ class ImageTester(unittest.TestCase):
img_ljpeg
=
decode_jpeg
(
torch
.
from_file
(
img_path
,
dtype
=
torch
.
uint8
,
size
=
size
))
self
.
assertTrue
(
img_ljpeg
.
equal
(
img_pil
))
with
self
.
assertRaisesRegex
(
Valu
eError
,
"Expected a non empty 1-dimensional tensor
.
"
):
with
self
.
assertRaisesRegex
(
Runtim
eError
,
"Expected a non empty 1-dimensional tensor"
):
decode_jpeg
(
torch
.
empty
((
100
,
1
),
dtype
=
torch
.
uint8
))
with
self
.
assertRaisesRegex
(
Valu
eError
,
"Expected a torch.uint8 tensor
.
"
):
with
self
.
assertRaisesRegex
(
Runtim
eError
,
"Expected a torch.uint8 tensor"
):
decode_jpeg
(
torch
.
empty
((
100
,
),
dtype
=
torch
.
float16
))
with
self
.
assertRaises
(
RuntimeError
):
...
...
@@ -149,11 +149,24 @@ class ImageTester(unittest.TestCase):
img_lpng
=
decode_png
(
torch
.
from_file
(
img_path
,
dtype
=
torch
.
uint8
,
size
=
size
))
self
.
assertTrue
(
img_lpng
.
equal
(
img_pil
))
with
self
.
assertRaises
(
Valu
eError
):
with
self
.
assertRaises
(
Runtim
eError
):
decode_png
(
torch
.
empty
((),
dtype
=
torch
.
uint8
))
with
self
.
assertRaises
(
RuntimeError
):
decode_png
(
torch
.
randint
(
3
,
5
,
(
300
,),
dtype
=
torch
.
uint8
))
def
test_decode_image
(
self
):
for
img_path
in
get_images
(
IMAGE_ROOT
,
".jpg"
):
img_pil
=
torch
.
load
(
img_path
.
replace
(
'jpg'
,
'pth'
))
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
img_ljpeg
=
decode_image
(
_read_file
(
img_path
))
self
.
assertTrue
(
img_ljpeg
.
equal
(
img_pil
))
for
img_path
in
get_images
(
IMAGE_DIR
,
".png"
):
img_pil
=
torch
.
from_numpy
(
np
.
array
(
Image
.
open
(
img_path
)))
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
img_lpng
=
decode_image
(
_read_file
(
img_path
))
self
.
assertTrue
(
img_lpng
.
equal
(
img_pil
))
if
__name__
==
'__main__'
:
unittest
.
main
()
torchvision/csrc/cpu/image/image.cpp
View file @
6afb3496
...
...
@@ -16,4 +16,5 @@ static auto registry = torch::RegisterOperators()
.
op
(
"image::decode_png"
,
&
decodePNG
)
.
op
(
"image::decode_jpeg"
,
&
decodeJPEG
)
.
op
(
"image::encode_jpeg"
,
&
encodeJPEG
)
.
op
(
"image::write_jpeg"
,
&
writeJPEG
);
.
op
(
"image::write_jpeg"
,
&
writeJPEG
)
.
op
(
"image::decode_image"
,
&
decode_image
);
torchvision/csrc/cpu/image/image.h
View file @
6afb3496
#pragma once
// Comment
#include <torch/script.h>
#include <torch/torch.h>
#include "read_image_cpu.h"
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include "writejpeg_cpu.h"
torchvision/csrc/cpu/image/read_image_cpu.cpp
0 → 100644
View file @
6afb3496
#include "read_image_cpu.h"
#include <string.h>
torch
::
Tensor
decode_image
(
const
torch
::
Tensor
&
data
)
{
// Check that the input tensor dtype is uint8
TORCH_CHECK
(
data
.
dtype
()
==
torch
::
kU8
,
"Expected a torch.uint8 tensor"
);
// Check that the input tensor is 1-dimensional
TORCH_CHECK
(
data
.
dim
()
==
1
&&
data
.
numel
()
>
0
,
"Expected a non empty 1-dimensional tensor"
);
auto
datap
=
data
.
data_ptr
<
uint8_t
>
();
const
uint8_t
jpeg_signature
[
3
]
=
{
255
,
216
,
255
};
// == "\xFF\xD8\xFF"
const
uint8_t
png_signature
[
4
]
=
{
137
,
80
,
78
,
71
};
// == "\211PNG"
if
(
memcmp
(
jpeg_signature
,
datap
,
3
)
==
0
)
{
return
decodeJPEG
(
data
);
}
else
if
(
memcmp
(
png_signature
,
datap
,
4
)
==
0
)
{
return
decodePNG
(
data
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported image file. Only jpeg and png "
,
"are currently supported."
);
}
}
torchvision/csrc/cpu/image/read_image_cpu.h
0 → 100644
View file @
6afb3496
#pragma once
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
C10_EXPORT
torch
::
Tensor
decode_image
(
const
torch
::
Tensor
&
data
);
torchvision/csrc/cpu/image/readjpeg_cpu.cpp
View file @
6afb3496
...
...
@@ -72,6 +72,13 @@ static void torch_jpeg_set_source_mgr(
}
torch
::
Tensor
decodeJPEG
(
const
torch
::
Tensor
&
data
)
{
// Check that the input tensor dtype is uint8
TORCH_CHECK
(
data
.
dtype
()
==
torch
::
kU8
,
"Expected a torch.uint8 tensor"
);
// Check that the input tensor is 1-dimensional
TORCH_CHECK
(
data
.
dim
()
==
1
&&
data
.
numel
()
>
0
,
"Expected a non empty 1-dimensional tensor"
);
struct
jpeg_decompress_struct
cinfo
;
struct
torch_jpeg_error_mgr
jerr
;
...
...
torchvision/csrc/cpu/image/readpng_cpu.cpp
View file @
6afb3496
...
...
@@ -13,6 +13,13 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
#include <png.h>
torch
::
Tensor
decodePNG
(
const
torch
::
Tensor
&
data
)
{
// Check that the input tensor dtype is uint8
TORCH_CHECK
(
data
.
dtype
()
==
torch
::
kU8
,
"Expected a torch.uint8 tensor"
);
// Check that the input tensor is 1-dimensional
TORCH_CHECK
(
data
.
dim
()
==
1
&&
data
.
numel
()
>
0
,
"Expected a non empty 1-dimensional tensor"
);
auto
png_ptr
=
png_create_read_struct
(
PNG_LIBPNG_VER_STRING
,
nullptr
,
nullptr
,
nullptr
);
TORCH_CHECK
(
png_ptr
,
"libpng read structure allocation failed!"
)
...
...
torchvision/io/image.py
View file @
6afb3496
...
...
@@ -23,23 +23,29 @@ except (ImportError, OSError):
pass
def
_read_file
(
path
:
str
)
->
torch
.
Tensor
:
if
not
os
.
path
.
isfile
(
path
):
raise
ValueError
(
"Expected a valid file path."
)
size
=
os
.
path
.
getsize
(
path
)
if
size
==
0
:
raise
ValueError
(
"Expected a non empty file."
)
data
=
torch
.
from_file
(
path
,
dtype
=
torch
.
uint8
,
size
=
size
)
return
data
def
decode_png
(
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
input (Tensor[1]): a one dimensional int8 tensor containing
input (Tensor[1]): a one dimensional
u
int8 tensor containing
the raw bytes of the PNG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
if
not
isinstance
(
input
,
torch
.
Tensor
)
or
input
.
numel
()
==
0
or
input
.
ndim
!=
1
:
# type: ignore[attr-defined]
raise
ValueError
(
"Expected a non empty 1-dimensional tensor."
)
if
not
input
.
dtype
==
torch
.
uint8
:
raise
ValueError
(
"Expected a torch.uint8 tensor."
)
output
=
torch
.
ops
.
image
.
decode_png
(
input
)
return
output
...
...
@@ -55,13 +61,7 @@ def read_png(path: str) -> torch.Tensor:
Returns:
output (Tensor[3, image_height, image_width])
"""
if
not
os
.
path
.
isfile
(
path
):
raise
ValueError
(
"Expected a valid file path."
)
size
=
os
.
path
.
getsize
(
path
)
if
size
==
0
:
raise
ValueError
(
"Expected a non empty file."
)
data
=
torch
.
from_file
(
path
,
dtype
=
torch
.
uint8
,
size
=
size
)
data
=
_read_file
(
path
)
return
decode_png
(
data
)
...
...
@@ -70,17 +70,11 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
Decodes a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
input (Tensor[1]): a one dimensional int8 tensor containing
input (Tensor[1]): a one dimensional
u
int8 tensor containing
the raw bytes of the JPEG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
if
not
isinstance
(
input
,
torch
.
Tensor
)
or
len
(
input
)
==
0
or
input
.
ndim
!=
1
:
# type: ignore[attr-defined]
raise
ValueError
(
"Expected a non empty 1-dimensional tensor."
)
if
not
input
.
dtype
==
torch
.
uint8
:
raise
ValueError
(
"Expected a torch.uint8 tensor."
)
output
=
torch
.
ops
.
image
.
decode_jpeg
(
input
)
return
output
...
...
@@ -94,13 +88,7 @@ def read_jpeg(path: str) -> torch.Tensor:
Returns:
output (Tensor[3, image_height, image_width])
"""
if
not
os
.
path
.
isfile
(
path
):
raise
ValueError
(
"Expected a valid file path."
)
size
=
os
.
path
.
getsize
(
path
)
if
size
==
0
:
raise
ValueError
(
"Expected a non empty file."
)
data
=
torch
.
from_file
(
path
,
dtype
=
torch
.
uint8
,
size
=
size
)
data
=
_read_file
(
path
)
return
decode_jpeg
(
data
)
...
...
@@ -141,3 +129,33 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
'between 1 and 100'
)
torch
.
ops
.
image
.
write_jpeg
(
input
,
filename
,
quality
)
def
decode_image
(
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Detects whether an image is a JPEG or PNG and performs the appropriate
operation to decode the image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
input (Tensor): a one dimensional uint8 tensor containing
the raw bytes of the PNG or JPEG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
output
=
torch
.
ops
.
image
.
decode_image
(
input
)
return
output
def
read_image
(
path
:
str
)
->
torch
.
Tensor
:
"""
Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the JPEG or PNG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
data
=
_read_file
(
path
)
return
decode_image
(
data
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment