Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
61a52b93
Unverified
Commit
61a52b93
authored
Feb 02, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Feb 02, 2022
Browse files
Add --prototype flag to quantization scripts. (#5334)
parent
9d7177fe
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
5 deletions
+13
-5
references/classification/train_quantization.py
references/classification/train_quantization.py
+13
-5
No files found.
references/classification/train_quantization.py
View file @
61a52b93
...
@@ -13,14 +13,16 @@ from train import train_one_epoch, evaluate, load_data
...
@@ -13,14 +13,16 @@ from train import train_one_epoch, evaluate, load_data
try
:
try
:
from
torchvision
.prototype
import
models
as
PM
from
torchvision
import
prototype
except
ImportError
:
except
ImportError
:
PM
=
None
prototype
=
None
def
main
(
args
):
def
main
(
args
):
if
args
.
weights
and
PM
is
None
:
if
args
.
prototype
and
prototype
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
if
not
args
.
prototype
and
args
.
weights
:
raise
ValueError
(
"The weights parameter works only in prototype mode. Please pass the --prototype argument."
)
if
args
.
output_dir
:
if
args
.
output_dir
:
utils
.
mkdir
(
args
.
output_dir
)
utils
.
mkdir
(
args
.
output_dir
)
...
@@ -54,10 +56,10 @@ def main(args):
...
@@ -54,10 +56,10 @@ def main(args):
print
(
"Creating model"
,
args
.
model
)
print
(
"Creating model"
,
args
.
model
)
# when training quantized models, we always start from a pre-trained fp32 reference model
# when training quantized models, we always start from a pre-trained fp32 reference model
if
not
args
.
weights
:
if
not
args
.
prototype
:
model
=
torchvision
.
models
.
quantization
.
__dict__
[
args
.
model
](
pretrained
=
True
,
quantize
=
args
.
test_only
)
model
=
torchvision
.
models
.
quantization
.
__dict__
[
args
.
model
](
pretrained
=
True
,
quantize
=
args
.
test_only
)
else
:
else
:
model
=
PM
.
quantization
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
quantize
=
args
.
test_only
)
model
=
prototype
.
models
.
quantization
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
quantize
=
args
.
test_only
)
model
.
to
(
device
)
model
.
to
(
device
)
if
not
(
args
.
test_only
or
args
.
post_training_quantize
):
if
not
(
args
.
test_only
or
args
.
post_training_quantize
):
...
@@ -264,6 +266,12 @@ def get_args_parser(add_help=True):
...
@@ -264,6 +266,12 @@ def get_args_parser(add_help=True):
parser
.
add_argument
(
"--clip-grad-norm"
,
default
=
None
,
type
=
float
,
help
=
"the maximum gradient norm (default None)"
)
parser
.
add_argument
(
"--clip-grad-norm"
,
default
=
None
,
type
=
float
,
help
=
"the maximum gradient norm (default None)"
)
# Prototype models only
# Prototype models only
parser
.
add_argument
(
"--prototype"
,
dest
=
"prototype"
,
help
=
"Use prototype model builders instead those from main area"
,
action
=
"store_true"
,
)
parser
.
add_argument
(
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"the weights enum name to load"
)
parser
.
add_argument
(
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"the weights enum name to load"
)
return
parser
return
parser
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment