Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
08c9938f
Unverified
Commit
08c9938f
authored
Jul 07, 2023
by
Nicolas Hug
Committed by
GitHub
Jul 07, 2023
Browse files
Add --use-v2 support to classification references (#7724)
parent
23b0938f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
26 deletions
+45
-26
references/classification/presets.py
references/classification/presets.py
+42
-26
references/classification/train.py
references/classification/train.py
+3
-0
No files found.
references/classification/presets.py
View file @
08c9938f
import
torch
import
torch
from
torchvision.transforms
import
autoaugment
,
transforms
from
torchvision.transforms.functional
import
InterpolationMode
from
torchvision.transforms.functional
import
InterpolationMode
def get_module(use_v2):
    """Return the torchvision transforms namespace selected by ``use_v2``.

    Args:
        use_v2: When truthy, ``torchvision.transforms.v2`` is imported and
            returned; otherwise ``torchvision.transforms`` is returned.

    Returns:
        The imported transforms module.
    """
    # The imports are kept inside the function on purpose: importing the V2
    # namespace emits a warning, so it must only happen when V2 was actually
    # requested and never when just V1 is used.
    if not use_v2:
        import torchvision.transforms

        return torchvision.transforms

    import torchvision.transforms.v2

    return torchvision.transforms.v2
class
ClassificationPresetTrain
:
class
ClassificationPresetTrain
:
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -17,41 +28,44 @@ class ClassificationPresetTrain:
...
@@ -17,41 +28,44 @@ class ClassificationPresetTrain:
augmix_severity
=
3
,
augmix_severity
=
3
,
random_erase_prob
=
0.0
,
random_erase_prob
=
0.0
,
backend
=
"pil"
,
backend
=
"pil"
,
use_v2
=
False
,
):
):
trans
=
[]
module
=
get_module
(
use_v2
)
transforms
=
[]
backend
=
backend
.
lower
()
backend
=
backend
.
lower
()
if
backend
==
"tensor"
:
if
backend
==
"tensor"
:
trans
.
append
(
transforms
.
PILToTensor
())
trans
forms
.
append
(
module
.
PILToTensor
())
elif
backend
!=
"pil"
:
elif
backend
!=
"pil"
:
raise
ValueError
(
f
"backend can be 'tensor' or 'pil', but got
{
backend
}
"
)
raise
ValueError
(
f
"backend can be 'tensor' or 'pil', but got
{
backend
}
"
)
trans
.
append
(
transforms
.
RandomResizedCrop
(
crop_size
,
interpolation
=
interpolation
,
antialias
=
True
))
trans
forms
.
append
(
module
.
RandomResizedCrop
(
crop_size
,
interpolation
=
interpolation
,
antialias
=
True
))
if
hflip_prob
>
0
:
if
hflip_prob
>
0
:
trans
.
append
(
transforms
.
RandomHorizontalFlip
(
hflip_prob
))
trans
forms
.
append
(
module
.
RandomHorizontalFlip
(
hflip_prob
))
if
auto_augment_policy
is
not
None
:
if
auto_augment_policy
is
not
None
:
if
auto_augment_policy
==
"ra"
:
if
auto_augment_policy
==
"ra"
:
trans
.
append
(
autoaugment
.
RandAugment
(
interpolation
=
interpolation
,
magnitude
=
ra_magnitude
))
trans
forms
.
append
(
module
.
RandAugment
(
interpolation
=
interpolation
,
magnitude
=
ra_magnitude
))
elif
auto_augment_policy
==
"ta_wide"
:
elif
auto_augment_policy
==
"ta_wide"
:
trans
.
append
(
autoaugment
.
TrivialAugmentWide
(
interpolation
=
interpolation
))
trans
forms
.
append
(
module
.
TrivialAugmentWide
(
interpolation
=
interpolation
))
elif
auto_augment_policy
==
"augmix"
:
elif
auto_augment_policy
==
"augmix"
:
trans
.
append
(
autoaugment
.
AugMix
(
interpolation
=
interpolation
,
severity
=
augmix_severity
))
trans
forms
.
append
(
module
.
AugMix
(
interpolation
=
interpolation
,
severity
=
augmix_severity
))
else
:
else
:
aa_policy
=
autoaugment
.
AutoAugmentPolicy
(
auto_augment_policy
)
aa_policy
=
module
.
AutoAugmentPolicy
(
auto_augment_policy
)
trans
.
append
(
autoaugment
.
AutoAugment
(
policy
=
aa_policy
,
interpolation
=
interpolation
))
trans
forms
.
append
(
module
.
AutoAugment
(
policy
=
aa_policy
,
interpolation
=
interpolation
))
if
backend
==
"pil"
:
if
backend
==
"pil"
:
trans
.
append
(
transforms
.
PILToTensor
())
trans
forms
.
append
(
module
.
PILToTensor
())
trans
.
extend
(
trans
forms
.
extend
(
[
[
transforms
.
ConvertImageDtype
(
torch
.
float
),
module
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
Normalize
(
mean
=
mean
,
std
=
std
),
module
.
Normalize
(
mean
=
mean
,
std
=
std
),
]
]
)
)
if
random_erase_prob
>
0
:
if
random_erase_prob
>
0
:
trans
.
append
(
transforms
.
RandomErasing
(
p
=
random_erase_prob
))
trans
forms
.
append
(
module
.
RandomErasing
(
p
=
random_erase_prob
))
self
.
transforms
=
transforms
.
Compose
(
trans
)
self
.
transforms
=
module
.
Compose
(
trans
forms
)
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
return
self
.
transforms
(
img
)
return
self
.
transforms
(
img
)
...
@@ -67,28 +81,30 @@ class ClassificationPresetEval:
...
@@ -67,28 +81,30 @@ class ClassificationPresetEval:
std
=
(
0.229
,
0.224
,
0.225
),
std
=
(
0.229
,
0.224
,
0.225
),
interpolation
=
InterpolationMode
.
BILINEAR
,
interpolation
=
InterpolationMode
.
BILINEAR
,
backend
=
"pil"
,
backend
=
"pil"
,
use_v2
=
False
,
):
):
trans
=
[]
module
=
get_module
(
use_v2
)
transforms
=
[]
backend
=
backend
.
lower
()
backend
=
backend
.
lower
()
if
backend
==
"tensor"
:
if
backend
==
"tensor"
:
trans
.
append
(
transforms
.
PILToTensor
())
trans
forms
.
append
(
module
.
PILToTensor
())
elif
backend
!=
"pil"
:
elif
backend
!=
"pil"
:
raise
ValueError
(
f
"backend can be 'tensor' or 'pil', but got
{
backend
}
"
)
raise
ValueError
(
f
"backend can be 'tensor' or 'pil', but got
{
backend
}
"
)
trans
+=
[
trans
forms
+=
[
transforms
.
Resize
(
resize_size
,
interpolation
=
interpolation
,
antialias
=
True
),
module
.
Resize
(
resize_size
,
interpolation
=
interpolation
,
antialias
=
True
),
transforms
.
CenterCrop
(
crop_size
),
module
.
CenterCrop
(
crop_size
),
]
]
if
backend
==
"pil"
:
if
backend
==
"pil"
:
trans
.
append
(
transforms
.
PILToTensor
())
trans
forms
.
append
(
module
.
PILToTensor
())
trans
+=
[
trans
forms
+=
[
transforms
.
ConvertImageDtype
(
torch
.
float
),
module
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
Normalize
(
mean
=
mean
,
std
=
std
),
module
.
Normalize
(
mean
=
mean
,
std
=
std
),
]
]
self
.
transforms
=
transforms
.
Compose
(
trans
)
self
.
transforms
=
module
.
Compose
(
trans
forms
)
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
return
self
.
transforms
(
img
)
return
self
.
transforms
(
img
)
references/classification/train.py
View file @
08c9938f
...
@@ -145,6 +145,7 @@ def load_data(traindir, valdir, args):
...
@@ -145,6 +145,7 @@ def load_data(traindir, valdir, args):
ra_magnitude
=
ra_magnitude
,
ra_magnitude
=
ra_magnitude
,
augmix_severity
=
augmix_severity
,
augmix_severity
=
augmix_severity
,
backend
=
args
.
backend
,
backend
=
args
.
backend
,
use_v2
=
args
.
use_v2
,
),
),
)
)
if
args
.
cache_dataset
:
if
args
.
cache_dataset
:
...
@@ -172,6 +173,7 @@ def load_data(traindir, valdir, args):
...
@@ -172,6 +173,7 @@ def load_data(traindir, valdir, args):
resize_size
=
val_resize_size
,
resize_size
=
val_resize_size
,
interpolation
=
interpolation
,
interpolation
=
interpolation
,
backend
=
args
.
backend
,
backend
=
args
.
backend
,
use_v2
=
args
.
use_v2
,
)
)
dataset_test
=
torchvision
.
datasets
.
ImageFolder
(
dataset_test
=
torchvision
.
datasets
.
ImageFolder
(
...
@@ -516,6 +518,7 @@ def get_args_parser(add_help=True):
...
@@ -516,6 +518,7 @@ def get_args_parser(add_help=True):
)
)
parser
.
add_argument
(
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"the weights enum name to load"
)
parser
.
add_argument
(
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"the weights enum name to load"
)
parser
.
add_argument
(
"--backend"
,
default
=
"PIL"
,
type
=
str
.
lower
,
help
=
"PIL or tensor - case insensitive"
)
parser
.
add_argument
(
"--backend"
,
default
=
"PIL"
,
type
=
str
.
lower
,
help
=
"PIL or tensor - case insensitive"
)
parser
.
add_argument
(
"--use-v2"
,
action
=
"store_true"
,
help
=
"Use V2 transforms"
)
return
parser
return
parser
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment