Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
934ce3b8
Unverified
Commit
934ce3b8
authored
Jan 12, 2023
by
Philip Meier
Committed by
GitHub
Jan 12, 2023
Browse files
fix MNIST byte flipping (#7081)
* fix MNIST byte flipping * add test * move to utils * remove lazy import
parent
372f4fae
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
32 additions
and
7 deletions
+32
-7
test/test_datasets_utils.py
test/test_datasets_utils.py
+20
-0
torchvision/datasets/mnist.py
torchvision/datasets/mnist.py
+6
-7
torchvision/datasets/utils.py
torchvision/datasets/utils.py
+6
-0
No files found.
test/test_datasets_utils.py
View file @
934ce3b8
...
...
@@ -7,7 +7,9 @@ import tarfile
import
zipfile
import
pytest
import
torch
import
torchvision.datasets.utils
as
utils
from
common_utils
import
assert_equal
from
torch._utils_internal
import
get_file_path_2
from
torchvision.datasets.folder
import
make_dataset
from
torchvision.datasets.utils
import
_COMPRESSED_FILE_OPENERS
...
...
@@ -215,6 +217,24 @@ class TestDatasetsUtils:
pytest
.
raises
(
ValueError
,
utils
.
verify_str_arg
,
0
,
(
"a"
,),
"arg"
)
pytest
.
raises
(
ValueError
,
utils
.
verify_str_arg
,
"b"
,
(
"a"
,),
"arg"
)
@pytest.mark.parametrize(
    ("dtype", "actual_hex", "expected_hex"),
    [
        # Single-byte elements: nothing to swap, the buffer is unchanged.
        (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
        # Two-byte elements: bytes are swapped pairwise.
        (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
        # Four-byte elements: each group of four bytes is reversed.
        (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
        # Eight-byte element: the whole buffer is reversed.
        (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
    ],
)
def test_flip_byte_order(self, dtype, actual_hex, expected_hex):
    """Check that ``utils._flip_byte_order`` reverses the bytes of each element.

    The same eight input bytes are interpreted as ``dtype``; the expected
    buffer holds those bytes reversed within every element of that dtype.
    """

    def tensor_from_hex(hex_string):
        # Decode the space-separated hex dump and reinterpret it as `dtype`.
        raw = bytes.fromhex(hex_string)
        return torch.frombuffer(raw, dtype=dtype)

    flipped = utils._flip_byte_order(tensor_from_hex(actual_hex))
    expected = tensor_from_hex(expected_hex)
    assert_equal(flipped, expected)
@
pytest
.
mark
.
parametrize
(
(
"kwargs"
,
"expected_error_msg"
),
...
...
torchvision/datasets/mnist.py
View file @
934ce3b8
...
...
@@ -12,7 +12,7 @@ import numpy as np
import
torch
from
PIL
import
Image
from
.utils
import
check_integrity
,
download_and_extract_archive
,
extract_archive
,
verify_str_arg
from
.utils
import
_flip_byte_order
,
check_integrity
,
download_and_extract_archive
,
extract_archive
,
verify_str_arg
from
.vision
import
VisionDataset
...
...
@@ -519,13 +519,12 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
torch_type
=
SN3_PASCALVINCENT_TYPEMAP
[
ty
]
s
=
[
get_int
(
data
[
4
*
(
i
+
1
)
:
4
*
(
i
+
2
)])
for
i
in
range
(
nd
)]
num_bytes_per_value
=
torch
.
iinfo
(
torch_type
).
bits
//
8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal
=
sys
.
byteorder
==
"little"
and
num_bytes_per_value
>
1
parsed
=
torch
.
frombuffer
(
bytearray
(
data
),
dtype
=
torch_type
,
offset
=
(
4
*
(
nd
+
1
)))
if
needs_byte_reversal
:
parsed
=
parsed
.
flip
(
0
)
# The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
# that is little endian and the dtype has more than one byte, we need to flip them.
if
sys
.
byteorder
==
"little"
and
parsed
.
element_size
()
>
1
:
parsed
=
_flip_byte_order
(
parsed
)
assert
parsed
.
shape
[
0
]
==
np
.
prod
(
s
)
or
not
strict
return
parsed
.
view
(
*
s
)
...
...
torchvision/datasets/utils.py
View file @
934ce3b8
...
...
@@ -520,3 +520,9 @@ def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
data
=
np
.
flip
(
data
,
axis
=
1
)
# flip on h dimension
data
=
data
[:
slice_channels
,
:,
:]
return
data
.
astype
(
np
.
float32
)
def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``t`` in which the bytes of every element are reversed.

    This converts between big and little endian representations of the same
    data. Elements that occupy a single byte come back with unchanged values.

    Args:
        t: Tensor of any shape (including 0-dim and empty tensors). It does not
            need to be contiguous.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``t``.
    """
    element_size = t.element_size()
    # Flatten before reinterpreting as bytes: viewing a 0-dim tensor as a dtype
    # of a different element size is not allowed, whereas a 1-D tensor always is.
    raw_bytes = t.contiguous().reshape(-1).view(torch.uint8)
    # One row per element; explicit sizes (rather than -1) keep the reshape
    # unambiguous when the tensor has no elements.
    flipped = raw_bytes.view(t.numel(), element_size).flip(-1)
    return flipped.reshape(-1).view(t.dtype).view(t.shape)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment