Commit 5a3d82e8 (unverified) in OpenDAS / nni
Authored Jul 14, 2022 by J-shang, committed by GitHub on Jul 14, 2022

[Compression] lightning & legacy evaluator - step 1 (#4950)

Parent: 0a57438b

Showing 7 changed files with 1053 additions and 24 deletions (+1053 -24).
Changed files:
- nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py (+4 -5)
- nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py (+1 -1)
- nni/algorithms/compression/v2/pytorch/utils/__init__.py (+13 -1)
- nni/algorithms/compression/v2/pytorch/utils/constructor_helper.py (+9 -12)
- nni/algorithms/compression/v2/pytorch/utils/evaluator.py (+727 -0, new file)
- nni/retiarii/hub/pytorch/utils.py (+0 -5, deleted)
- test/algo/compression/v2/test_evaluator.py (+299 -0, new file)
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py

@@ -12,7 +12,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.optim import Optimizer
-from nni.common.serializer import Traceable

 from ..base import Pruner
 from .tools import (
@@ -523,7 +522,7 @@ class SlimPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor],
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
                  training_epochs: int, scale: float = 0.0001, mode='global'):
         self.mode = mode
         self.trainer = trainer
@@ -633,7 +632,7 @@ class ActivationPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
         self.mode = mode
         self.dummy_input = dummy_input
@@ -957,7 +956,7 @@ class TaylorFOWeightPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
         self.mode = mode
         self.dummy_input = dummy_input
@@ -1099,7 +1098,7 @@ class ADMMPruner(BasicPruner):
     """
     def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
                  training_epochs: int, granularity: str = 'fine-grained'):
         self.trainer = trainer
         if isinstance(traced_optimizer, OptimizerConstructHelper):
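Note: only the type annotation of traced_optimizer changes in the hunks above (Traceable to Optimizer); the pruners are still constructed by passing an optimizer created through nni.trace, which is both traceable and a genuine Optimizer instance. The following is a hedged sketch of a SlimPruner call under the new annotation; the toy model, trainer loop and config values are illustrative assumptions, not part of this commit.

import nni
import torch
import torch.nn.functional as F
from nni.algorithms.compression.v2.pytorch.pruning import SlimPruner

# a toy model with a BatchNorm layer for SlimPruner to act on (placeholder)
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3),
    torch.nn.BatchNorm2d(8),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 10),
)
config_list = [{'sparsity': 0.5, 'op_types': ['BatchNorm2d']}]

def trainer(model, optimizer, criterion):
    # a single step on random data, just to drive the sparsity regularization (placeholder)
    optimizer.zero_grad()
    x, y = torch.rand(4, 1, 16, 16), torch.randint(0, 10, (4,))
    criterion(model(x), y).backward()
    optimizer.step()

# traced_optimizer is an Optimizer produced via nni.trace, matching the new annotation
traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
pruner = SlimPruner(model, config_list, trainer, traced_optimizer, F.cross_entropy, training_epochs=1)
pruned_model, masks = pruner.compress()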
nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py

@@ -161,7 +161,7 @@ class MovementPruner(BasicPruner):
     For detailed example please refer to :githublink:`examples/model_compress/pruning/movement_pruning_glue.py <examples/model_compress/pruning/movement_pruning_glue.py>`
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
                  cool_down_beginning_step: int):
         self.trainer = trainer
         if isinstance(traced_optimizer, OptimizerConstructHelper):
nni/algorithms/compression/v2/pytorch/utils/__init__.py

@@ -6,7 +6,19 @@ from .attr import (
     set_nested_attr
 )
 from .config_validation import CompressorSchema
-from .constructor_helper import *
+from .constructor_helper import (
+    OptimizerConstructHelper,
+    LRSchedulerConstructHelper
+)
+from .evaluator import (
+    Evaluator,
+    LightningEvaluator,
+    TorchEvaluator,
+    Hook,
+    BackwardHook,
+    ForwardHook,
+    TensorHook
+)
 from .pruning import (
     config_list_canonical,
     unfold_config_list,
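With the star import replaced by explicit re-exports, the evaluator classes added in this commit become importable from the utils package itself. A minimal sketch (names taken from the hunk above, usage assumed):

# assumption: these names resolve because the __init__.py above re-exports them
from nni.algorithms.compression.v2.pytorch.utils import (
    LightningEvaluator, TorchEvaluator, TensorHook, ForwardHook, BackwardHook,
    OptimizerConstructHelper, LRSchedulerConstructHelper,
)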
nni/algorithms/compression/v2/pytorch/utils/constructor_helper.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations

 from copy import deepcopy
 from typing import Callable, Dict, List, Type
@@ -9,7 +11,6 @@ from torch.nn import Module
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler

-from nni.common.serializer import _trace_cls
 from nni.common.serializer import Traceable, is_traceable

 __all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']
@@ -60,14 +61,15 @@ class OptimizerConstructHelper(ConstructHelper):
         return param_groups

-    def names2params(self, wrapped_model: Module, origin2wrapped_name_map: Dict, params: List[Dict]) -> List[Dict]:
+    def names2params(self, wrapped_model: Module, origin2wrapped_name_map: Dict | None, params: List[Dict]) -> List[Dict]:
         param_groups = deepcopy(params)
+        origin2wrapped_name_map = origin2wrapped_name_map if origin2wrapped_name_map else {}
         for param_group in param_groups:
             wrapped_names = [origin2wrapped_name_map.get(name, name) for name in param_group['params']]
             param_group['params'] = [p for name, p in wrapped_model.named_parameters() if name in wrapped_names]
         return param_groups

-    def call(self, wrapped_model: Module, origin2wrapped_name_map: Dict) -> Optimizer:
+    def call(self, wrapped_model: Module, origin2wrapped_name_map: Dict | None) -> Optimizer:
         args = deepcopy(self.args)
         kwargs = deepcopy(self.kwargs)
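The behavioural change in this hunk is that origin2wrapped_name_map may now be None; it then falls back to an empty dict, so every parameter name simply maps to itself. A tiny standalone illustration of that fallback (not nni code):

def resolve(name: str, origin2wrapped_name_map: dict | None = None) -> str:
    # same fallback pattern as names2params above
    origin2wrapped_name_map = origin2wrapped_name_map if origin2wrapped_name_map else {}
    return origin2wrapped_name_map.get(name, name)

assert resolve('conv1.weight') == 'conv1.weight'  # None map: name maps to itself
assert resolve('conv1.weight', {'conv1.weight': 'conv1.module.weight'}) == 'conv1.module.weight'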
@@ -79,15 +81,12 @@ class OptimizerConstructHelper(ConstructHelper):
         return self.callable_obj(*args, **kwargs)

     @staticmethod
-    def from_trace(model: Module, optimizer_trace: Traceable):
+    def from_trace(model: Module, optimizer_trace: Optimizer):
         assert is_traceable(optimizer_trace), \
             'Please use nni.trace to wrap the optimizer class before initialize the optimizer.'
         assert isinstance(optimizer_trace, Optimizer), \
             'It is not an instance of torch.nn.Optimizer.'
-        return OptimizerConstructHelper(model,
-                                        optimizer_trace.trace_symbol,
-                                        *optimizer_trace.trace_args,
-                                        **optimizer_trace.trace_kwargs)
+        return OptimizerConstructHelper(model, optimizer_trace.trace_symbol, *optimizer_trace.trace_args, **optimizer_trace.trace_kwargs)  # type: ignore


 class LRSchedulerConstructHelper(ConstructHelper):
@@ -111,11 +110,9 @@ class LRSchedulerConstructHelper(ConstructHelper):
         return self.callable_obj(*args, **kwargs)

     @staticmethod
-    def from_trace(lr_scheduler_trace: Traceable):
+    def from_trace(lr_scheduler_trace: _LRScheduler):
         assert is_traceable(lr_scheduler_trace), \
             'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
         assert isinstance(lr_scheduler_trace, _LRScheduler), \
             'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
-        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol,
-                                          *lr_scheduler_trace.trace_args,
-                                          **lr_scheduler_trace.trace_kwargs)
+        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, **lr_scheduler_trace.trace_kwargs)  # type: ignore
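Both from_trace helpers keep the same runtime contract: the argument must have been created through nni.trace (so is_traceable passes) and must also be a real Optimizer or _LRScheduler instance (so the isinstance check passes); the new annotations simply state the latter. A hedged sketch of feeding them, following the pattern used in the new test file; the Linear model and hyperparameters are placeholders:

import nni
import torch
from torch.optim.lr_scheduler import ExponentialLR
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper, LRSchedulerConstructHelper

model = torch.nn.Linear(4, 2)
# created through nni.trace, so traceable and a genuine Optimizer / _LRScheduler instance
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)

opt_helper = OptimizerConstructHelper.from_trace(model, optimizer)
sched_helper = LRSchedulerConstructHelper.from_trace(scheduler)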
nni/algorithms/compression/v2/pytorch/utils/evaluator.py (new file, mode 100644, +727 lines)

This diff is collapsed.
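Although the new evaluator.py diff is collapsed here, its public surface can be read off the test file at the end of this commit: TorchEvaluator wraps a training function, a traced optimizer, a criterion and a scheduler; LightningEvaluator wraps a traced pl.Trainer plus data module; both expose bind_model, patch_loss, patch_optimizer_step, register_hooks, train and finetune. A hedged sketch of the TorchEvaluator path, reusing SimpleTorchModel, training_model and evaluating_model as defined in the test file below; anything beyond what the test exercises is an assumption:

import nni
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
from nni.algorithms.compression.v2.pytorch.utils import TorchEvaluator

model = SimpleTorchModel()  # defined in test_evaluator.py below
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9)
lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)

evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler,
                           evaluating_func=evaluating_model)
evaluator._init_optimizer_helpers(model)  # the test calls this internal hook directly
evaluator.bind_model(model)
evaluator.train(max_steps=1)   # one patched training step
evaluator.finetune()           # run the full training function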
nni/retiarii/hub/pytorch/utils.py (deleted, mode 100644 → 0)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Useful type hints
test/algo/compression/v2/test_evaluator.py (new file, mode 100644, +299 lines)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from pathlib import Path
from typing import Callable

import pytest

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from torch.nn import Module
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler
from torch.utils.data import random_split, DataLoader
from torchmetrics.functional import accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import nni
from nni.algorithms.compression.v2.pytorch.utils.evaluator import (
    TorchEvaluator,
    LightningEvaluator,
    TensorHook,
    ForwardHook,
    BackwardHook,
)


class SimpleTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 3)
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
        self.bn2 = torch.nn.BatchNorm2d(8)
        self.conv3 = torch.nn.Conv2d(16, 8, 3)
        self.bn3 = torch.nn.BatchNorm2d(8)
        self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x: torch.Tensor):
        x = self.bn1(self.conv1(x))
        x = self.bn2(self.conv2(x)) + self.bn3(self.conv3(x))
        x = self.fc2(self.fc1(x.reshape(x.shape[0], -1)))
        return F.log_softmax(x, -1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler,
                   max_steps: int | None = None, max_epochs: int | None = None):
    model.train()

    # prepare data
    data_dir = Path(__file__).parent / 'data'
    MNIST(data_dir, train=True, download=True)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    mnist_train = MNIST(data_dir, train=True, transform=transform)
    train_dataloader = DataLoader(mnist_train, batch_size=32)

    max_epochs = max_epochs if max_epochs else 1
    max_steps = max_steps if max_steps else 10
    current_steps = 0

    # training
    for _ in range(max_epochs):
        for x, y in train_dataloader:
            optimizer.zero_grad()
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss: torch.Tensor = criterion(logits, y)
            loss.backward()
            optimizer.step()
            current_steps += 1
            if max_steps and current_steps == max_steps:
                return
        scheduler.step()


def evaluating_model(model: Module):
    model.eval()

    # prepare data
    data_dir = Path(__file__).parent / 'data'
    MNIST(data_dir, train=False, download=True)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    mnist_test = MNIST(data_dir, train=False, transform=transform)
    test_dataloader = DataLoader(mnist_test, batch_size=32)

    # testing
    correct = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = torch.argmax(logits, dim=1)
            correct += preds.eq(y.view_as(preds)).sum().item()
    return correct / len(mnist_test)


class SimpleLightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = SimpleTorchModel()
        self.count = 0

    def forward(self, x):
        print(self.count)
        self.count += 1
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = nni.trace(torch.optim.SGD)(
            self.parameters(),
            lr=0.01,
            momentum=0.9,
            weight_decay=5e-4,
        )
        scheduler_dict = {
            "scheduler": nni.trace(ExponentialLR)(
                optimizer,
                0.1,
            ),
            "interval": "epoch",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str | None = None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict" or stage is None:
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)


optimizer_before_step_flag = False
optimizer_after_step_flag = False
loss_flag = False

def optimizer_before_step_patch():
    global optimizer_before_step_flag
    optimizer_before_step_flag = True

def optimizer_after_step_patch():
    global optimizer_after_step_flag
    optimizer_after_step_flag = True

def loss_patch(t: torch.Tensor):
    global loss_flag
    loss_flag = True
    return t

def tensor_hook_factory(buffer: list):
    def hook_func(t: torch.Tensor):
        buffer.append(True)
    return hook_func

def forward_hook_factory(buffer: list):
    def hook_func(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
        buffer.append(True)
    return hook_func

def backward_hook_factory(buffer: list):
    def hook_func(module: torch.nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor):
        buffer.append(True)
    return hook_func

def reset_flags():
    global optimizer_before_step_flag, optimizer_after_step_flag, loss_flag
    optimizer_before_step_flag = False
    optimizer_after_step_flag = False
    loss_flag = False

def assert_flags():
    global optimizer_before_step_flag, optimizer_after_step_flag, loss_flag
    assert optimizer_before_step_flag, 'Evaluator patch optimizer before step failed.'
    assert optimizer_after_step_flag, 'Evaluator patch optimizer after step failed.'
    assert loss_flag, 'Evaluator patch loss failed.'


def create_lighting_evaluator():
    pl_model = SimpleLightningModel()
    pl_trainer = nni.trace(pl.Trainer)(
        max_epochs=1,
        max_steps=10,
        logger=TensorBoardLogger(Path(__file__).parent / 'lightning_logs', name="resnet"),
    )
    pl_trainer.num_sanity_val_steps = 0
    pl_data = nni.trace(MNISTDataModule)(data_dir=Path(__file__).parent / 'data')
    evaluator = LightningEvaluator(pl_trainer, pl_data)
    evaluator._init_optimizer_helpers(pl_model)
    return evaluator


def create_pytorch_evaluator():
    model = SimpleTorchModel()
    optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)
    evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler, evaluating_func=evaluating_model)
    evaluator._init_optimizer_helpers(model)
    return evaluator


@pytest.mark.parametrize("evaluator_type", ['lightning', 'pytorch'])
def test_evaluator(evaluator_type: str):
    if evaluator_type == 'lightning':
        evaluator = create_lighting_evaluator()
        model = SimpleLightningModel()
        evaluator.bind_model(model)
        tensor_hook = TensorHook(model.model.conv1.weight, 'model.conv1.weight', tensor_hook_factory)
        forward_hook = ForwardHook(model.model.conv1, 'model.conv1', forward_hook_factory)
        backward_hook = BackwardHook(model.model.conv1, 'model.conv1', backward_hook_factory)
    elif evaluator_type == 'pytorch':
        evaluator = create_pytorch_evaluator()
        model = SimpleTorchModel().to(device)
        evaluator.bind_model(model)
        tensor_hook = TensorHook(model.conv1.weight, 'conv1.weight', tensor_hook_factory)
        forward_hook = ForwardHook(model.conv1, 'conv1', forward_hook_factory)
        backward_hook = BackwardHook(model.conv1, 'conv1', backward_hook_factory)
    else:
        raise ValueError(f'wrong evaluator_type: {evaluator_type}')

    # test train with patch & hook
    reset_flags()
    evaluator.patch_loss(loss_patch)
    evaluator.patch_optimizer_step([optimizer_before_step_patch], [optimizer_after_step_patch])
    evaluator.register_hooks([tensor_hook, forward_hook, backward_hook])
    evaluator.train(max_steps=1)
    assert_flags()
    assert all([len(hook.buffer) == 1 for hook in [tensor_hook, forward_hook, backward_hook]])

    # test finetune with patch & hook
    reset_flags()
    evaluator.remove_all_hooks()
    evaluator.register_hooks([tensor_hook, forward_hook, backward_hook])
    evaluator.finetune()
    assert_flags()
    assert all([len(hook.buffer) == 10 for hook in [tensor_hook, forward_hook, backward_hook]])
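The parametrized test above covers both the lightning and the pytorch evaluator paths. Assuming a checkout of this commit with pytest, pytorch-lightning, torchvision and torchmetrics installed, it can be run on its own with: pytest test/algo/compression/v2/test_evaluator.py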