Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
0ab7d05c
Unverified
Commit
0ab7d05c
authored
May 31, 2023
by
Nicolas Hug
Committed by
GitHub
May 31, 2023
Browse files
Allow classification references to use the tensor backend (#7629)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
a6f63879
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
13 deletions
+46
-13
references/classification/presets.py
references/classification/presets.py
+35
-11
references/classification/train.py
references/classification/train.py
+11
-2
No files found.
references/classification/presets.py
View file @
0ab7d05c
...
@@ -16,8 +16,16 @@ class ClassificationPresetTrain:
...
@@ -16,8 +16,16 @@ class ClassificationPresetTrain:
ra_magnitude
=
9
,
ra_magnitude
=
9
,
augmix_severity
=
3
,
augmix_severity
=
3
,
random_erase_prob
=
0.0
,
random_erase_prob
=
0.0
,
backend
=
"pil"
,
):
):
trans
=
[
transforms
.
RandomResizedCrop
(
crop_size
,
interpolation
=
interpolation
)]
trans
=
[]
backend
=
backend
.
lower
()
if
backend
==
"tensor"
:
trans
.
append
(
transforms
.
PILToTensor
())
elif
backend
!=
"pil"
:
raise
ValueError
(
f
"backend can be 'tensor' or 'pil', but got
{
backend
}
"
)
trans
.
append
(
transforms
.
RandomResizedCrop
(
crop_size
,
interpolation
=
interpolation
,
antialias
=
True
))
if
hflip_prob
>
0
:
if
hflip_prob
>
0
:
trans
.
append
(
transforms
.
RandomHorizontalFlip
(
hflip_prob
))
trans
.
append
(
transforms
.
RandomHorizontalFlip
(
hflip_prob
))
if
auto_augment_policy
is
not
None
:
if
auto_augment_policy
is
not
None
:
...
@@ -30,9 +38,12 @@ class ClassificationPresetTrain:
...
@@ -30,9 +38,12 @@ class ClassificationPresetTrain:
else
:
else
:
aa_policy
=
autoaugment
.
AutoAugmentPolicy
(
auto_augment_policy
)
aa_policy
=
autoaugment
.
AutoAugmentPolicy
(
auto_augment_policy
)
trans
.
append
(
autoaugment
.
AutoAugment
(
policy
=
aa_policy
,
interpolation
=
interpolation
))
trans
.
append
(
autoaugment
.
AutoAugment
(
policy
=
aa_policy
,
interpolation
=
interpolation
))
if
backend
==
"pil"
:
trans
.
append
(
transforms
.
PILToTensor
())
trans
.
extend
(
trans
.
extend
(
[
[
transforms
.
PILToTensor
(),
transforms
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
Normalize
(
mean
=
mean
,
std
=
std
),
transforms
.
Normalize
(
mean
=
mean
,
std
=
std
),
]
]
...
@@ -55,17 +66,30 @@ class ClassificationPresetEval:
...
@@ -55,17 +66,30 @@ class ClassificationPresetEval:
mean
=
(
0.485
,
0.456
,
0.406
),
mean
=
(
0.485
,
0.456
,
0.406
),
std
=
(
0.229
,
0.224
,
0.225
),
std
=
(
0.229
,
0.224
,
0.225
),
interpolation
=
InterpolationMode
.
BILINEAR
,
interpolation
=
InterpolationMode
.
BILINEAR
,
backend
=
"pil"
,
):
):
trans
=
[]
self
.
transforms
=
transforms
.
Compose
(
backend
=
backend
.
lower
()
[
if
backend
==
"tensor"
:
transforms
.
Resize
(
resize_size
,
interpolation
=
interpolation
),
trans
.
append
(
transforms
.
PILToTensor
())
transforms
.
CenterCrop
(
crop_size
),
else
:
transforms
.
PILToTensor
(),
raise
ValueError
(
f
"backend can be 'tensor' or 'pil', but got
{
backend
}
"
)
transforms
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
Normalize
(
mean
=
mean
,
std
=
std
),
trans
+=
[
]
transforms
.
Resize
(
resize_size
,
interpolation
=
interpolation
,
antialias
=
True
),
)
transforms
.
CenterCrop
(
crop_size
),
]
if
backend
==
"pil"
:
trans
.
append
(
transforms
.
PILToTensor
())
trans
+=
[
transforms
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
Normalize
(
mean
=
mean
,
std
=
std
),
]
self
.
transforms
=
transforms
.
Compose
(
trans
)
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
return
self
.
transforms
(
img
)
return
self
.
transforms
(
img
)
references/classification/train.py
View file @
0ab7d05c
...
@@ -7,6 +7,7 @@ import presets
...
@@ -7,6 +7,7 @@ import presets
import
torch
import
torch
import
torch.utils.data
import
torch.utils.data
import
torchvision
import
torchvision
import
torchvision.transforms
import
transforms
import
transforms
import
utils
import
utils
from
sampler
import
RASampler
from
sampler
import
RASampler
...
@@ -143,6 +144,7 @@ def load_data(traindir, valdir, args):
...
@@ -143,6 +144,7 @@ def load_data(traindir, valdir, args):
random_erase_prob
=
random_erase_prob
,
random_erase_prob
=
random_erase_prob
,
ra_magnitude
=
ra_magnitude
,
ra_magnitude
=
ra_magnitude
,
augmix_severity
=
augmix_severity
,
augmix_severity
=
augmix_severity
,
backend
=
args
.
backend
,
),
),
)
)
if
args
.
cache_dataset
:
if
args
.
cache_dataset
:
...
@@ -160,10 +162,16 @@ def load_data(traindir, valdir, args):
...
@@ -160,10 +162,16 @@ def load_data(traindir, valdir, args):
else
:
else
:
if
args
.
weights
and
args
.
test_only
:
if
args
.
weights
and
args
.
test_only
:
weights
=
torchvision
.
models
.
get_weight
(
args
.
weights
)
weights
=
torchvision
.
models
.
get_weight
(
args
.
weights
)
preprocessing
=
weights
.
transforms
()
preprocessing
=
weights
.
transforms
(
antialias
=
True
)
if
args
.
backend
==
"tensor"
:
preprocessing
=
torchvision
.
transforms
.
Compose
([
torchvision
.
transforms
.
PILToTensor
(),
preprocessing
])
else
:
else
:
preprocessing
=
presets
.
ClassificationPresetEval
(
preprocessing
=
presets
.
ClassificationPresetEval
(
crop_size
=
val_crop_size
,
resize_size
=
val_resize_size
,
interpolation
=
interpolation
crop_size
=
val_crop_size
,
resize_size
=
val_resize_size
,
interpolation
=
interpolation
,
backend
=
args
.
backend
,
)
)
dataset_test
=
torchvision
.
datasets
.
ImageFolder
(
dataset_test
=
torchvision
.
datasets
.
ImageFolder
(
...
@@ -507,6 +515,7 @@ def get_args_parser(add_help=True):
...
@@ -507,6 +515,7 @@ def get_args_parser(add_help=True):
"--ra-reps"
,
default
=
3
,
type
=
int
,
help
=
"number of repetitions for Repeated Augmentation (default: 3)"
"--ra-reps"
,
default
=
3
,
type
=
int
,
help
=
"number of repetitions for Repeated Augmentation (default: 3)"
)
)
parser
.
add_argument
(
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"the weights enum name to load"
)
parser
.
add_argument
(
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"the weights enum name to load"
)
parser
.
add_argument
(
"--backend"
,
default
=
"PIL"
,
type
=
str
.
lower
,
help
=
"PIL or tensor - case insensitive"
)
return
parser
return
parser
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment