Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
3300692c
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3c05b9f71c82e4cdaef579cb13f363b6c1d7964d"
Unverified
Commit
3300692c
authored
Nov 03, 2021
by
Vasilis Vryniotis
Committed by
GitHub
Nov 03, 2021
Browse files
Moving the check for prototype support in all references. (#4849)
parent
dd1adb07
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
10 additions
and
10 deletions
+10
-10
references/classification/train.py
references/classification/train.py
+2
-2
references/classification/train_quantization.py
references/classification/train_quantization.py
+2
-2
references/detection/train.py
references/detection/train.py
+2
-2
references/segmentation/train.py
references/segmentation/train.py
+2
-2
references/video_classification/train.py
references/video_classification/train.py
+2
-2
No files found.
references/classification/train.py
View file @
3300692c
...
@@ -182,6 +182,8 @@ def load_data(traindir, valdir, args):
...
@@ -182,6 +182,8 @@ def load_data(traindir, valdir, args):
def
main
(
args
):
def
main
(
args
):
if
args
.
weights
and
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
if
args
.
output_dir
:
if
args
.
output_dir
:
utils
.
mkdir
(
args
.
output_dir
)
utils
.
mkdir
(
args
.
output_dir
)
...
@@ -226,8 +228,6 @@ def main(args):
...
@@ -226,8 +228,6 @@ def main(args):
if
not
args
.
weights
:
if
not
args
.
weights
:
model
=
torchvision
.
models
.
__dict__
[
args
.
model
](
pretrained
=
args
.
pretrained
,
num_classes
=
num_classes
)
model
=
torchvision
.
models
.
__dict__
[
args
.
model
](
pretrained
=
args
.
pretrained
,
num_classes
=
num_classes
)
else
:
else
:
if
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
model
=
PM
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
num_classes
=
num_classes
)
model
=
PM
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
num_classes
=
num_classes
)
model
.
to
(
device
)
model
.
to
(
device
)
...
...
references/classification/train_quantization.py
View file @
3300692c
...
@@ -19,6 +19,8 @@ except ImportError:
...
@@ -19,6 +19,8 @@ except ImportError:
def
main
(
args
):
def
main
(
args
):
if
args
.
weights
and
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
if
args
.
output_dir
:
if
args
.
output_dir
:
utils
.
mkdir
(
args
.
output_dir
)
utils
.
mkdir
(
args
.
output_dir
)
...
@@ -55,8 +57,6 @@ def main(args):
...
@@ -55,8 +57,6 @@ def main(args):
if
not
args
.
weights
:
if
not
args
.
weights
:
model
=
torchvision
.
models
.
quantization
.
__dict__
[
args
.
model
](
pretrained
=
True
,
quantize
=
args
.
test_only
)
model
=
torchvision
.
models
.
quantization
.
__dict__
[
args
.
model
](
pretrained
=
True
,
quantize
=
args
.
test_only
)
else
:
else
:
if
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
model
=
PM
.
quantization
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
quantize
=
args
.
test_only
)
model
=
PM
.
quantization
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
quantize
=
args
.
test_only
)
model
.
to
(
device
)
model
.
to
(
device
)
...
...
references/detection/train.py
View file @
3300692c
...
@@ -148,6 +148,8 @@ def get_args_parser(add_help=True):
...
@@ -148,6 +148,8 @@ def get_args_parser(add_help=True):
def
main
(
args
):
def
main
(
args
):
if
args
.
weights
and
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
if
args
.
output_dir
:
if
args
.
output_dir
:
utils
.
mkdir
(
args
.
output_dir
)
utils
.
mkdir
(
args
.
output_dir
)
...
@@ -194,8 +196,6 @@ def main(args):
...
@@ -194,8 +196,6 @@ def main(args):
pretrained
=
args
.
pretrained
,
num_classes
=
num_classes
,
**
kwargs
pretrained
=
args
.
pretrained
,
num_classes
=
num_classes
,
**
kwargs
)
)
else
:
else
:
if
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
model
=
PM
.
detection
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
num_classes
=
num_classes
,
**
kwargs
)
model
=
PM
.
detection
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
num_classes
=
num_classes
,
**
kwargs
)
model
.
to
(
device
)
model
.
to
(
device
)
if
args
.
distributed
and
args
.
sync_bn
:
if
args
.
distributed
and
args
.
sync_bn
:
...
...
references/segmentation/train.py
View file @
3300692c
...
@@ -92,6 +92,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
...
@@ -92,6 +92,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
def
main
(
args
):
def
main
(
args
):
if
args
.
weights
and
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
if
args
.
output_dir
:
if
args
.
output_dir
:
utils
.
mkdir
(
args
.
output_dir
)
utils
.
mkdir
(
args
.
output_dir
)
...
@@ -130,8 +132,6 @@ def main(args):
...
@@ -130,8 +132,6 @@ def main(args):
aux_loss
=
args
.
aux_loss
,
aux_loss
=
args
.
aux_loss
,
)
)
else
:
else
:
if
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
model
=
PM
.
segmentation
.
__dict__
[
args
.
model
](
model
=
PM
.
segmentation
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
,
num_classes
=
num_classes
,
aux_loss
=
args
.
aux_loss
weights
=
args
.
weights
,
num_classes
=
num_classes
,
aux_loss
=
args
.
aux_loss
)
)
...
...
references/video_classification/train.py
View file @
3300692c
...
@@ -99,6 +99,8 @@ def collate_fn(batch):
...
@@ -99,6 +99,8 @@ def collate_fn(batch):
def
main
(
args
):
def
main
(
args
):
if
args
.
weights
and
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
if
args
.
apex
and
amp
is
None
:
if
args
.
apex
and
amp
is
None
:
raise
RuntimeError
(
raise
RuntimeError
(
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
...
@@ -214,8 +216,6 @@ def main(args):
...
@@ -214,8 +216,6 @@ def main(args):
if
not
args
.
weights
:
if
not
args
.
weights
:
model
=
torchvision
.
models
.
video
.
__dict__
[
args
.
model
](
pretrained
=
args
.
pretrained
)
model
=
torchvision
.
models
.
video
.
__dict__
[
args
.
model
](
pretrained
=
args
.
pretrained
)
else
:
else
:
if
PM
is
None
:
raise
ImportError
(
"The prototype module couldn't be found. Please install the latest torchvision nightly."
)
model
=
PM
.
video
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
)
model
=
PM
.
video
.
__dict__
[
args
.
model
](
weights
=
args
.
weights
)
model
.
to
(
device
)
model
.
to
(
device
)
if
args
.
distributed
and
args
.
sync_bn
:
if
args
.
distributed
and
args
.
sync_bn
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment