OpenDAS / nni, commit b122c63d (unverified)

Update finetuning with kd example (#3412)

Authored Mar 03, 2021 by colorjam; committed by GitHub on Mar 03, 2021.
Parent commit: 969f0d99
Showing 1 changed file with 14 additions and 11 deletions.

examples/model_compress/pruning/finetune_kd_torch.py (+14, -11)
@@ -7,24 +7,21 @@ Run basic_pruners_torch.py first to get the masks of the pruned model. Then pass
 '''
 import argparse
 import os
 import time
-import argparse
+from copy import deepcopy
+import nni
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR, MultiStepLR
+from nni.compression.pytorch import ModelSpeedup
+from torch.optim.lr_scheduler import MultiStepLR, StepLR
 from torchvision import datasets, transforms
-from copy import deepcopy
-from models.mnist.lenet import LeNet
-from models.cifar10.vgg import VGG
-from basic_pruners_torch import get_data
-import nni
-from nni.compression.pytorch import ModelSpeedup, get_dummy_input
+from basic_pruners_torch import get_data
+from models.cifar10.vgg import VGG
+from models.mnist.lenet import LeNet
 
 class DistillKL(nn.Module):
     """Distilling the Knowledge in a Neural Network"""
@@ -38,6 +35,13 @@ class DistillKL(nn.Module):
         loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
         return loss
 
+def get_dummy_input(args, device):
+    if args.dataset == 'mnist':
+        dummy_input = torch.randn([args.test_batch_size, 1, 28, 28]).to(device)
+    elif args.dataset in ['cifar10', 'imagenet']:
+        dummy_input = torch.randn([args.test_batch_size, 3, 32, 32]).to(device)
+    return dummy_input
+
 def get_model_optimizer_scheduler(args, device, test_loader, criterion):
     if args.model == 'LeNet':
         model = LeNet().to(device)
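The get_dummy_input helper added here pairs with dropping get_dummy_input from the nni.compression.pytorch import, so the example now builds its own tracing input. Below is a sketch of how it presumably feeds ModelSpeedup when the pruned student is sped up; speedup_student and args.masks_path are hypothetical names, and the ModelSpeedup call follows the usual NNI 2.x pattern (model, dummy input, masks file), which may vary between versions.

def speedup_student(args, model, device):
    # args.masks_path is an assumed attribute pointing at the mask file
    # written by basic_pruners_torch.py; it is not defined in this diff.
    dummy_input = get_dummy_input(args, device)
    ms = ModelSpeedup(model, dummy_input, args.masks_path)
    ms.speedup_model()  # rewrite masked layers into physically smaller ones
    return model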
@@ -51,7 +55,6 @@ def get_model_optimizer_scheduler(args, device, test_loader, criterion):
     # In this example, we set the architecture of teacher and student to be the same. It is feasible to set a different teacher architecture.
     if args.teacher_model_dir is None:
         raise NotImplementedError('please load pretrained teacher model first')
     else:
         model.load_state_dict(torch.load(args.teacher_model_dir))
     best_acc = test(args, model, device, criterion, test_loader)
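The teacher checkpoint loaded above is only used to produce soft targets while the pruned student is fine-tuned. Below is a sketch of one distillation fine-tuning step, assuming the example mixes plain cross-entropy with the DistillKL loss; the train_step name, the alpha weight, and the temperature value are illustrative assumptions rather than code from this file, and it relies on the torch imports from the header.

def train_step(student, teacher, data, target, optimizer, alpha=0.9, T=4.0):
    criterion_ce = nn.CrossEntropyLoss()
    criterion_kd = DistillKL(T)
    teacher.eval()
    with torch.no_grad():
        logits_t = teacher(data)  # teacher provides soft targets only
    logits_s = student(data)
    loss = alpha * criterion_kd(logits_s, logits_t) + (1 - alpha) * criterion_ce(logits_s, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()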