Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a6bed3cf
Unverified
Commit
a6bed3cf
authored
Dec 27, 2021
by
J-shang
Committed by
GitHub
Dec 27, 2021
Browse files
[Bugbash] update example import path (#4423)
parent
dbf842a6
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
46 additions
and
34 deletions
+46
-34
examples/model_compress/pruning/amc/amc_train.py
examples/model_compress/pruning/amc/amc_train.py
+2
-1
examples/model_compress/pruning/auto_pruners_torch.py
examples/model_compress/pruning/auto_pruners_torch.py
+4
-2
examples/model_compress/pruning/basic_pruners_torch.py
examples/model_compress/pruning/basic_pruners_torch.py
+2
-1
examples/model_compress/pruning/finetune_kd_torch.py
examples/model_compress/pruning/finetune_kd_torch.py
+5
-9
examples/model_compress/pruning/mobilenetv2_end2end/utils.py
examples/model_compress/pruning/mobilenetv2_end2end/utils.py
+3
-2
examples/model_compress/pruning/speedup/model_speedup.py
examples/model_compress/pruning/speedup/model_speedup.py
+3
-5
examples/model_compress/pruning/v2/activation_pruning_torch.py
...les/model_compress/pruning/v2/activation_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/admm_pruning_torch.py
examples/model_compress/pruning/v2/admm_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/amc_pruning_torch.py
examples/model_compress/pruning/v2/amc_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/auto_compress_pruner.py
examples/model_compress/pruning/v2/auto_compress_pruner.py
+2
-1
examples/model_compress/pruning/v2/fpgm_pruning_torch.py
examples/model_compress/pruning/v2/fpgm_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/iterative_pruning_torch.py
...ples/model_compress/pruning/v2/iterative_pruning_torch.py
+3
-2
examples/model_compress/pruning/v2/level_pruning_torch.py
examples/model_compress/pruning/v2/level_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/norm_pruning_torch.py
examples/model_compress/pruning/v2/norm_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/scheduler_torch.py
examples/model_compress/pruning/v2/scheduler_torch.py
+2
-1
examples/model_compress/pruning/v2/simple_pruning_torch.py
examples/model_compress/pruning/v2/simple_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/simulated_anealing_pruning_torch.py
...l_compress/pruning/v2/simulated_anealing_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/slim_pruning_torch.py
examples/model_compress/pruning/v2/slim_pruning_torch.py
+2
-1
examples/model_compress/pruning/v2/taylorfo_pruning_torch.py
examples/model_compress/pruning/v2/taylorfo_pruning_torch.py
+2
-1
No files found.
examples/model_compress/pruning/amc/amc_train.py
View file @
a6bed3cf
...
@@ -22,7 +22,8 @@ from nni.compression.pytorch import ModelSpeedup
...
@@ -22,7 +22,8 @@ from nni.compression.pytorch import ModelSpeedup
from
data
import
get_dataset
from
data
import
get_dataset
from
utils
import
AverageMeter
,
accuracy
,
progress_bar
from
utils
import
AverageMeter
,
accuracy
,
progress_bar
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
mobilenet
import
MobileNet
from
mobilenet
import
MobileNet
from
mobilenet_v2
import
MobileNetV2
from
mobilenet_v2
import
MobileNetV2
...
...
examples/model_compress/pruning/auto_pruners_torch.py
View file @
a6bed3cf
...
@@ -8,6 +8,7 @@ In this example, we present the usage of automatic pruners (NetAdapt, AutoCompre
...
@@ -8,6 +8,7 @@ In this example, we present the usage of automatic pruners (NetAdapt, AutoCompre
import
argparse
import
argparse
import
os
import
os
import
sys
import
json
import
json
import
torch
import
torch
from
torch.optim.lr_scheduler
import
StepLR
,
MultiStepLR
from
torch.optim.lr_scheduler
import
StepLR
,
MultiStepLR
...
@@ -18,12 +19,13 @@ from nni.algorithms.compression.pytorch.pruning import SimulatedAnnealingPruner,
...
@@ -18,12 +19,13 @@ from nni.algorithms.compression.pytorch.pruning import SimulatedAnnealingPruner,
from
nni.compression.pytorch
import
ModelSpeedup
from
nni.compression.pytorch
import
ModelSpeedup
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.compression.pytorch.utils.counter
import
count_flops_params
import
sys
from
pathlib
import
Path
sys
.
path
.
append
(
'../
models'
)
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
1
]
/
'
models'
)
)
from
mnist.lenet
import
LeNet
from
mnist.lenet
import
LeNet
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
from
cifar10.resnet
import
ResNet18
,
ResNet50
from
cifar10.resnet
import
ResNet18
,
ResNet50
def
get_data
(
dataset
,
data_dir
,
batch_size
,
test_batch_size
):
def
get_data
(
dataset
,
data_dir
,
batch_size
,
test_batch_size
):
'''
'''
get data
get data
...
...
examples/model_compress/pruning/basic_pruners_torch.py
View file @
a6bed3cf
...
@@ -17,7 +17,8 @@ import torch
...
@@ -17,7 +17,8 @@ import torch
from
torch.optim.lr_scheduler
import
StepLR
,
MultiStepLR
from
torch.optim.lr_scheduler
import
StepLR
,
MultiStepLR
from
torchvision
import
datasets
,
transforms
from
torchvision
import
datasets
,
transforms
sys
.
path
.
append
(
'../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
1
]
/
'models'
))
from
mnist.lenet
import
LeNet
from
mnist.lenet
import
LeNet
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
from
cifar10.resnet
import
ResNet18
from
cifar10.resnet
import
ResNet18
...
...
examples/model_compress/pruning/finetune_kd_torch.py
View file @
a6bed3cf
...
@@ -8,23 +8,20 @@ Run basic_pruners_torch.py first to get the masks of the pruned model. Then pass
...
@@ -8,23 +8,20 @@ Run basic_pruners_torch.py first to get the masks of the pruned model. Then pass
import
argparse
import
argparse
import
os
import
os
import
time
import
sys
from
copy
import
deepcopy
from
copy
import
deepcopy
import
nni
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.optim
as
optim
from
nni.compression.pytorch
import
ModelSpeedup
from
nni.compression.pytorch
import
ModelSpeedup
from
torch.optim.lr_scheduler
import
MultiStepLR
,
StepLR
from
torch.optim.lr_scheduler
import
MultiStepLR
from
torchvision
import
datasets
,
transforms
from
basic_pruners_torch
import
get_data
from
basic_pruners_torch
import
get_data
import
sys
from
pathlib
import
Path
sys
.
path
.
append
(
'../models'
)
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
1
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
mnist.lenet
import
LeNet
from
mnist.lenet
import
LeNet
from
cifar10.vgg
import
VGG
class
DistillKL
(
nn
.
Module
):
class
DistillKL
(
nn
.
Module
):
"""Distilling the Knowledge in a Neural Network"""
"""Distilling the Knowledge in a Neural Network"""
...
@@ -73,7 +70,6 @@ def get_model_optimizer_scheduler(args, device, test_loader, criterion):
...
@@ -73,7 +70,6 @@ def get_model_optimizer_scheduler(args, device, test_loader, criterion):
m_speedup
=
ModelSpeedup
(
model_s
,
dummy_input
,
args
.
mask_path
,
device
)
m_speedup
=
ModelSpeedup
(
model_s
,
dummy_input
,
args
.
mask_path
,
device
)
m_speedup
.
speedup_model
()
m_speedup
.
speedup_model
()
module_list
=
nn
.
ModuleList
([])
module_list
=
nn
.
ModuleList
([])
module_list
.
append
(
model_s
)
module_list
.
append
(
model_s
)
module_list
.
append
(
model_t
)
module_list
.
append
(
model_t
)
...
...
examples/model_compress/pruning/mobilenetv2_end2end/utils.py
View file @
a6bed3cf
...
@@ -2,14 +2,15 @@
...
@@ -2,14 +2,15 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
os
import
os
import
sys
import
torch
import
torch
from
torch.utils.data
import
Dataset
,
DataLoader
from
torch.utils.data
import
Dataset
,
DataLoader
import
torchvision.transforms
as
transforms
import
torchvision.transforms
as
transforms
import
numpy
as
np
import
numpy
as
np
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.compression.pytorch.utils.counter
import
count_flops_params
import
sys
from
pathlib
import
Path
sys
.
path
.
append
(
'../../
models'
)
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'
models'
)
)
from
mobilenet
import
MobileNet
from
mobilenet
import
MobileNet
from
mobilenet_v2
import
MobileNetV2
from
mobilenet_v2
import
MobileNetV2
...
...
examples/model_compress/pruning/speedup/model_speedup.py
View file @
a6bed3cf
import
os
import
os
import
sys
import
argparse
import
argparse
import
time
import
time
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torchvision
import
datasets
,
transforms
import
sys
from
pathlib
import
Path
sys
.
path
.
append
(
'../
models'
)
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
1
]
/
'
models'
)
)
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
from
mnist.lenet
import
LeNet
from
mnist.lenet
import
LeNet
...
...
examples/model_compress/pruning/v2/activation_pruning_torch.py
View file @
a6bed3cf
...
@@ -19,7 +19,8 @@ from nni.compression.pytorch.utils.counter import count_flops_params
...
@@ -19,7 +19,8 @@ from nni.compression.pytorch.utils.counter import count_flops_params
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
ActivationAPoZRankPruner
,
ActivationMeanRankPruner
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
ActivationAPoZRankPruner
,
ActivationMeanRankPruner
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
examples/model_compress/pruning/v2/admm_pruning_torch.py
View file @
a6bed3cf
...
@@ -18,7 +18,8 @@ from nni.compression.pytorch.utils.counter import count_flops_params
...
@@ -18,7 +18,8 @@ from nni.compression.pytorch.utils.counter import count_flops_params
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
ADMMPruner
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
ADMMPruner
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
examples/model_compress/pruning/v2/amc_pruning_torch.py
View file @
a6bed3cf
...
@@ -8,7 +8,8 @@ from torch.optim.lr_scheduler import MultiStepLR
...
@@ -8,7 +8,8 @@ from torch.optim.lr_scheduler import MultiStepLR
from
nni.algorithms.compression.v2.pytorch.pruning
import
AMCPruner
from
nni.algorithms.compression.v2.pytorch.pruning
import
AMCPruner
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.compression.pytorch.utils.counter
import
count_flops_params
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
...
...
examples/model_compress/pruning/v2/auto_compress_pruner.py
View file @
a6bed3cf
...
@@ -7,7 +7,8 @@ from torchvision import datasets, transforms
...
@@ -7,7 +7,8 @@ from torchvision import datasets, transforms
from
nni.algorithms.compression.v2.pytorch.pruning
import
AutoCompressPruner
from
nni.algorithms.compression.v2.pytorch.pruning
import
AutoCompressPruner
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
examples/model_compress/pruning/v2/fpgm_pruning_torch.py
View file @
a6bed3cf
...
@@ -18,7 +18,8 @@ from nni.compression.pytorch import ModelSpeedup
...
@@ -18,7 +18,8 @@ from nni.compression.pytorch import ModelSpeedup
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
FPGMPruner
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
FPGMPruner
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
examples/model_compress/pruning/v2/iterative_pruning_torch.py
View file @
a6bed3cf
...
@@ -19,7 +19,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
...
@@ -19,7 +19,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
LotteryTicketPruner
LotteryTicketPruner
)
)
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
...
...
examples/model_compress/pruning/v2/level_pruning_torch.py
View file @
a6bed3cf
...
@@ -17,7 +17,8 @@ from torch.optim.lr_scheduler import MultiStepLR
...
@@ -17,7 +17,8 @@ from torch.optim.lr_scheduler import MultiStepLR
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
LevelPruner
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
LevelPruner
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
examples/model_compress/pruning/v2/norm_pruning_torch.py
View file @
a6bed3cf
...
@@ -18,7 +18,8 @@ from nni.compression.pytorch import ModelSpeedup
...
@@ -18,7 +18,8 @@ from nni.compression.pytorch import ModelSpeedup
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.compression.pytorch.utils.counter
import
count_flops_params
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
L1NormPruner
,
L2NormPruner
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
L1NormPruner
,
L2NormPruner
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
examples/model_compress/pruning/v2/scheduler_torch.py
View file @
a6bed3cf
...
@@ -8,7 +8,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
...
@@ -8,7 +8,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from
nni.algorithms.compression.v2.pytorch.pruning.tools
import
AGPTaskGenerator
from
nni.algorithms.compression.v2.pytorch.pruning.tools
import
AGPTaskGenerator
from
nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler
import
PruningScheduler
from
nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler
import
PruningScheduler
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
...
...
examples/model_compress/pruning/v2/simple_pruning_torch.py
View file @
a6bed3cf
...
@@ -7,7 +7,8 @@ from torchvision import datasets, transforms
...
@@ -7,7 +7,8 @@ from torchvision import datasets, transforms
from
nni.algorithms.compression.v2.pytorch.pruning
import
L1NormPruner
from
nni.algorithms.compression.v2.pytorch.pruning
import
L1NormPruner
from
nni.compression.pytorch.speedup
import
ModelSpeedup
from
nni.compression.pytorch.speedup
import
ModelSpeedup
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
...
...
examples/model_compress/pruning/v2/simulated_anealing_pruning_torch.py
View file @
a6bed3cf
...
@@ -15,7 +15,8 @@ from torchvision import datasets, transforms
...
@@ -15,7 +15,8 @@ from torchvision import datasets, transforms
from
nni.algorithms.compression.v2.pytorch.pruning
import
SimulatedAnnealingPruner
from
nni.algorithms.compression.v2.pytorch.pruning
import
SimulatedAnnealingPruner
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
...
...
examples/model_compress/pruning/v2/slim_pruning_torch.py
View file @
a6bed3cf
...
@@ -19,7 +19,8 @@ from nni.compression.pytorch.utils.counter import count_flops_params
...
@@ -19,7 +19,8 @@ from nni.compression.pytorch.utils.counter import count_flops_params
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
SlimPruner
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
SlimPruner
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
examples/model_compress/pruning/v2/taylorfo_pruning_torch.py
View file @
a6bed3cf
...
@@ -19,7 +19,8 @@ from nni.compression.pytorch.utils.counter import count_flops_params
...
@@ -19,7 +19,8 @@ from nni.compression.pytorch.utils.counter import count_flops_params
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
TaylorFOWeightPruner
from
nni.algorithms.compression.v2.pytorch.pruning.basic_pruner
import
TaylorFOWeightPruner
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
from
nni.algorithms.compression.v2.pytorch.utils
import
trace_parameters
sys
.
path
.
append
(
'../../models'
)
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
absolute
().
parents
[
2
]
/
'models'
))
from
cifar10.vgg
import
VGG
from
cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment