Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
15b97d4b
Unverified
Commit
15b97d4b
authored
Sep 01, 2022
by
vfdev
Committed by
GitHub
Sep 01, 2022
Browse files
[proto] Fix BC for inplace arg in Normalize and RandomErasing (#6530)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
11304cb7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
13 additions
and
6 deletions
+13
-6
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+7
-1
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+0
-2
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+3
-1
torchvision/prototype/transforms/_misc.py
torchvision/prototype/transforms/_misc.py
+3
-2
No files found.
test/test_prototype_transforms.py
View file @
15b97d4b
...
@@ -1004,7 +1004,13 @@ class TestRandomErasing:
...
@@ -1004,7 +1004,13 @@ class TestRandomErasing:
if
p
:
if
p
:
mock
.
assert_called_once_with
(
mock
.
assert_called_once_with
(
inpt_sentinel
,
i
=
i_sentinel
,
j
=
j_sentinel
,
h
=
h_sentinel
,
w
=
w_sentinel
,
v
=
v_sentinel
inpt_sentinel
,
i
=
i_sentinel
,
j
=
j_sentinel
,
h
=
h_sentinel
,
w
=
w_sentinel
,
v
=
v_sentinel
,
inplace
=
transform
.
inplace
,
)
)
else
:
else
:
mock
.
assert_not_called
()
mock
.
assert_not_called
()
...
...
test/test_prototype_transforms_consistency.py
View file @
15b97d4b
...
@@ -88,7 +88,6 @@ CONSISTENCY_CONFIGS = [
...
@@ -88,7 +88,6 @@ CONSISTENCY_CONFIGS = [
],
],
supports_pil
=
False
,
supports_pil
=
False
,
make_images_kwargs
=
dict
(
DEFAULT_MAKE_IMAGES_KWARGS
,
dtypes
=
[
torch
.
float
]),
make_images_kwargs
=
dict
(
DEFAULT_MAKE_IMAGES_KWARGS
,
dtypes
=
[
torch
.
float
]),
removed_params
=
[
"inplace"
],
),
),
ConsistencyConfig
(
ConsistencyConfig
(
prototype_transforms
.
Resize
,
prototype_transforms
.
Resize
,
...
@@ -315,7 +314,6 @@ CONSISTENCY_CONFIGS = [
...
@@ -315,7 +314,6 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs
(
p
=
1
,
value
=
"random"
),
ArgsKwargs
(
p
=
1
,
value
=
"random"
),
],
],
supports_pil
=
False
,
supports_pil
=
False
,
removed_params
=
[
"inplace"
],
),
),
ConsistencyConfig
(
ConsistencyConfig
(
prototype_transforms
.
ColorJitter
,
prototype_transforms
.
ColorJitter
,
...
...
torchvision/prototype/transforms/_augment.py
View file @
15b97d4b
...
@@ -23,6 +23,7 @@ class RandomErasing(_RandomApplyTransform):
...
@@ -23,6 +23,7 @@ class RandomErasing(_RandomApplyTransform):
scale
:
Tuple
[
float
,
float
]
=
(
0.02
,
0.33
),
scale
:
Tuple
[
float
,
float
]
=
(
0.02
,
0.33
),
ratio
:
Tuple
[
float
,
float
]
=
(
0.3
,
3.3
),
ratio
:
Tuple
[
float
,
float
]
=
(
0.3
,
3.3
),
value
:
float
=
0
,
value
:
float
=
0
,
inplace
:
bool
=
False
,
):
):
super
().
__init__
(
p
=
p
)
super
().
__init__
(
p
=
p
)
if
not
isinstance
(
value
,
(
numbers
.
Number
,
str
,
tuple
,
list
)):
if
not
isinstance
(
value
,
(
numbers
.
Number
,
str
,
tuple
,
list
)):
...
@@ -40,6 +41,7 @@ class RandomErasing(_RandomApplyTransform):
...
@@ -40,6 +41,7 @@ class RandomErasing(_RandomApplyTransform):
self
.
scale
=
scale
self
.
scale
=
scale
self
.
ratio
=
ratio
self
.
ratio
=
ratio
self
.
value
=
value
self
.
value
=
value
self
.
inplace
=
inplace
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
img_c
,
img_h
,
img_w
=
query_chw
(
sample
)
img_c
,
img_h
,
img_w
=
query_chw
(
sample
)
...
@@ -92,7 +94,7 @@ class RandomErasing(_RandomApplyTransform):
...
@@ -92,7 +94,7 @@ class RandomErasing(_RandomApplyTransform):
self
,
inpt
:
Union
[
torch
.
Tensor
,
features
.
Image
,
PIL
.
Image
.
Image
],
params
:
Dict
[
str
,
Any
]
self
,
inpt
:
Union
[
torch
.
Tensor
,
features
.
Image
,
PIL
.
Image
.
Image
],
params
:
Dict
[
str
,
Any
]
)
->
Union
[
torch
.
Tensor
,
features
.
Image
,
PIL
.
Image
.
Image
]:
)
->
Union
[
torch
.
Tensor
,
features
.
Image
,
PIL
.
Image
.
Image
]:
if
params
[
"v"
]
is
not
None
:
if
params
[
"v"
]
is
not
None
:
inpt
=
F
.
erase
(
inpt
,
**
params
)
inpt
=
F
.
erase
(
inpt
,
**
params
,
inplace
=
self
.
inplace
)
return
inpt
return
inpt
...
...
torchvision/prototype/transforms/_misc.py
View file @
15b97d4b
...
@@ -95,13 +95,14 @@ class LinearTransformation(Transform):
...
@@ -95,13 +95,14 @@ class LinearTransformation(Transform):
class
Normalize
(
Transform
):
class
Normalize
(
Transform
):
_transformed_types
=
(
features
.
Image
,
features
.
is_simple_tensor
)
_transformed_types
=
(
features
.
Image
,
features
.
is_simple_tensor
)
def
__init__
(
self
,
mean
:
Sequence
[
float
],
std
:
Sequence
[
float
]):
def
__init__
(
self
,
mean
:
Sequence
[
float
],
std
:
Sequence
[
float
]
,
inplace
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
self
.
mean
=
list
(
mean
)
self
.
mean
=
list
(
mean
)
self
.
std
=
list
(
std
)
self
.
std
=
list
(
std
)
self
.
inplace
=
inplace
def
_transform
(
self
,
inpt
:
Union
[
torch
.
Tensor
,
features
.
_Feature
],
params
:
Dict
[
str
,
Any
])
->
torch
.
Tensor
:
def
_transform
(
self
,
inpt
:
Union
[
torch
.
Tensor
,
features
.
_Feature
],
params
:
Dict
[
str
,
Any
])
->
torch
.
Tensor
:
return
F
.
normalize
(
inpt
,
mean
=
self
.
mean
,
std
=
self
.
std
)
return
F
.
normalize
(
inpt
,
mean
=
self
.
mean
,
std
=
self
.
std
,
inplace
=
self
.
inplace
)
def
forward
(
self
,
*
inpts
:
Any
)
->
Any
:
def
forward
(
self
,
*
inpts
:
Any
)
->
Any
:
if
has_any
(
inpts
,
PIL
.
Image
.
Image
):
if
has_any
(
inpts
,
PIL
.
Image
.
Image
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment