Commit 2a5b726a authored Jun 10, 2020 by Kirthi Sivamani

asp files

parent 097238f8
Showing 7 changed files with 735 additions and 0 deletions
apex/contrib/sparsity/__init__.py                            +2   -0
apex/contrib/sparsity/asp.py                                 +207 -0
apex/contrib/sparsity/sparse_masklib.py                      +170 -0
apex/contrib/sparsity/test/checkpointing_test_part1.py       +94  -0
apex/contrib/sparsity/test/checkpointing_test_part2.py       +79  -0
apex/contrib/sparsity/test/checkpointing_test_reference.py   +96  -0
apex/contrib/sparsity/test/toy_problem.py                    +87  -0
apex/contrib/sparsity/__init__.py  0 → 100644

from .sparse_masklib import create_mask
from .asp import ASP
apex/contrib/sparsity/asp.py  0 → 100644

import types
import torch
from .sparse_masklib import create_mask

torchvision_imported = True
try:
    import torchvision
except ImportError:
    print("[ASP][Warning] torchvision cannot be imported, may influence functionality of MaskRCNN/KeypointRCNN network from torchvision.")
    torchvision_imported = False
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
    eligible_modules_list = []
    for name, mod in model.named_modules():
        if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:
            if allowed_layer_names is not None and name not in allowed_layer_names:
                continue
            eligible_modules_list.append((name, mod))
    return eligible_modules_list
class ASP:
    __model = None
    __verbosity = 0
    __optimizer = None
    __sparse_parameters = []
    __calculate_mask = None

    @classmethod
    def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
                               verbosity=3,
                               whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
                               allowed_layer_names=None, disallowed_layer_names=[],
                               allow_recompute_mask=False):
"""Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA,
it does not enable use of sparse MMA.
If you are starting with a fresh model:
model = ...
ASP.init_model_for_pruning(model, mask_calculator, ...)
if (training) ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks() // sparsity is off by default, call when youy want to enable it.
If you are starting from a checkpoint:
model = ...
ASP.init_model_for_pruning(model, mask_calculator, ...)
torch.load(...)
if (training) ASP.init_optimizer_for_pruning(optimizer)
Arguments:
model The model
mask_calculator Either callable that computes mask given a tensor OR pattern string for sparse mask lib.
verbosity Integer controling verbosity level.
0 -> Only errors.
1 -> Errors and warnings.
2 -> Errors, warnings and info.
3 -> Errors, warnings, info and debug.
whitelist Module types approved for sparsity.
allowed_layer_names If not None, only layer names that appear in this list are considered for sparsity.
disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity.
allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
Support for allow_recompute_mask can be removed, it is not part of our recipe -- AKM.
"""
        assert (cls.__model is None), "ASP has been initialized already."
        cls.__model = model
        cls.__verbosity = verbosity

        if isinstance(mask_calculator, str):
            def create_mask_from_pattern(param):
                return create_mask(param, mask_calculator).bool()
            cls.__calculate_mask = create_mask_from_pattern
        else:
            cls.__calculate_mask = mask_calculator  # user defined function

        # function to extract variables that will be sparsified.
        # idea is that you will add one of these functions for each module type that can be sparsified.
        if torchvision_imported:
            print("[ASP] torchvision is imported, can work smoothly with the MaskRCNN/KeypointRCNN from torchvision.")
            sparse_parameter_list = {torch.nn.Linear: ['weight'],
                                     torch.nn.Conv1d: ['weight'],
                                     torch.nn.Conv2d: ['weight'],
                                     torch.nn.Conv3d: ['weight'],
                                     torchvision.ops.misc.Conv2d: ['weight']}
        else:
            sparse_parameter_list = {torch.nn.Linear: ['weight'],
                                     torch.nn.Conv1d: ['weight'],
                                     torch.nn.Conv2d: ['weight'],
                                     torch.nn.Conv3d: ['weight']}

        for module_type in whitelist:
            assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module_type
        # find all sparse modules, extract sparse parameters and decorate
        def add_sparse_attributes(module_name, module):
            sparse_parameters = sparse_parameter_list[type(module)]
            for p_name, p in module.named_parameters():
                if p_name in sparse_parameters and p.requires_grad:
                    # check for NVIDIA's TC compatibility: we check along the horizontal direction
                    if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0):  # User defines FP32 and APEX internally uses FP16 math
                        print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
                        continue
                    if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0):  # For Conv2d dim = K x CRS; we prune along C
                        print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
                        continue

                    if cls.__verbosity >= 3:
                        print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))

                    mask = torch.ones_like(p).bool()
                    buffname = module_name.split(".")[-1]  # buffer names cannot contain "."
                    module.register_buffer('__%s_mma_mask' % buffname, mask)
                    if allow_recompute_mask:
                        pruned = torch.zeros_like(p).cpu()
                        module.register_buffer('__%s_mma_pruned_p' % buffname, pruned)
                    else:
                        pruned = None
                    cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))

        for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
            add_sparse_attributes(name, sparse_module)
    @classmethod
    def init_optimizer_for_pruning(cls, optimizer):
        """Call this method to monkey patch optimizer step function so that masks can be applied to
        gradients and weights during training.
        You must call init_model_for_pruning(...) before calling init_optimizer_for_pruning(...)
        """
        assert (cls.__optimizer is None), "ASP has initialized optimizer already."
        assert (cls.__calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."

        # store pointer to original optimizer step method
        cls.__optimizer = optimizer
        cls.__optimizer.__step = optimizer.step

        def __step(opt_self, *args, **kwargs):
            # prune gradients before step method
            with torch.no_grad():
                for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                    p.grad.mul_(mask)
            # call original optimizer step method
            rval = opt_self.__step(*args, **kwargs)
            # prune parameters after step method
            with torch.no_grad():
                for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                    p.mul_(mask)
            return rval
        cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)
    @classmethod
    def compute_sparse_masks(cls):
        """Call this method to enable sparsity.
        If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
        """
        with torch.no_grad():
            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                if mask.sum() < mask.numel():  # when recalculating masks
                    # restore dense parameter if allow_recompute_mask is enabled
                    assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
                    p.add_(pruned.cuda())

                mask.set_(cls.__calculate_mask(p))

                if pruned is not None:  # stow away pruned weights to cpu
                    pruned.set_((p * (~mask)).cpu())

                p.mul_(mask)  # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
                if cls.__verbosity >= 2:
                    print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))
    @classmethod
    def restore_pruned_weights(cls):
        """Call this method to disable sparsity and restore all weights.
        This will only work if init(...) was called with allow_recompute_mask=True.
        """
        with torch.no_grad():
            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                if mask.sum() < mask.numel():
                    assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
                    p.add_(pruned.cuda())
                    mask.fill_(1)
                    pruned.zero_()
                    if cls.__verbosity >= 2:
                        print("[ASP] Disabled sparsity for %s::%s (dense weights restored)" % (module_name, p_name))
    @classmethod
    def is_sparsity_enabled(cls):
        """Call this method to determine if sparsity is enabled in the model.
        The typical use case is right after checkpoint has been loaded.
        """
        total, sp100, sp50 = 0, 0, 0
        for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
            total += 1
            mask_sum = mask.sum()
            mask_numel = mask.numel()
            if mask_sum == mask_numel:
                sp100 += 1
            elif mask_sum*2 == mask_numel:
                sp50 += 1

        assert (total == sp100 or total == sp50), "Inconsistent model sparsity"
        if total == sp100:
            return False
        elif total == sp50:
            return True
    @classmethod
    def prune_trained_model(cls, model, optimizer):
        # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
        cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)
        cls.init_optimizer_for_pruning(optimizer)
        cls.compute_sparse_masks()
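The docstrings above describe the intended call sequence; as a minimal, illustrative sketch of the one-call recipe for sparsifying an already-trained model (the model, optimizer, and data below are placeholders, not part of this commit):

    import torch
    from apex.contrib.sparsity import ASP

    # Illustrative dense model and optimizer; any model with eligible Linear/Conv layers works.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Registers mask buffers, monkey patches optimizer.step, and computes 2:4 masks.
    ASP.prune_trained_model(model, optimizer)

    # Fine-tune as usual; pruned weights and their gradients are kept at zero around every step.
    for _ in range(10):
        x = torch.randn(32, 64).cuda()
        loss = (model(x) ** 2).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()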
apex/contrib/sparsity/sparse_masklib.py  0 → 100644

import sys
import torch
import numpy as np
import collections
from itertools import permutations


""" compute density (helper fn to compute % NNZs in a tensor) """
def fill(x):
    return float(x.nonzero().size(0)) / torch.numel(x)
""" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) """
def
reshape_1d
(
matrix
,
m
):
# If not a nice multiple of m, fill with zeroes.
if
matrix
.
shape
[
1
]
%
m
>
0
:
mat
=
torch
.
cuda
.
FloatTensor
(
matrix
.
shape
[
0
],
matrix
.
shape
[
1
]
+
(
m
-
matrix
.
shape
[
1
]
%
m
)).
fill_
(
0
)
mat
[:,
:
matrix
.
shape
[
1
]]
=
matrix
shape
=
mat
.
shape
return
mat
.
view
(
-
1
,
m
),
shape
else
:
return
matrix
.
view
(
-
1
,
m
),
matrix
.
shape
""" return all possible m:n patterns in a 1d vector. """
valid_m4n2_1d_patterns
=
None
def
compute_valid_1d_patterns
(
m
,
n
):
# Early exit if patterns was already created.
global
valid_m4n2_1d_patterns
if
m
==
4
and
n
==
2
and
valid_m4n2_1d_patterns
is
not
None
:
return
valid_m4n2_1d_patterns
patterns
=
torch
.
zeros
(
m
)
patterns
[:
n
]
=
1
valid_patterns
=
torch
.
Tensor
(
list
(
set
(
permutations
(
patterns
.
tolist
()))))
if
m
==
4
and
n
==
2
:
valid_m4n2_1d_patterns
=
valid_patterns
return
valid_patterns
""" m:n 1d structured best """
def
mn_1d_best
(
matrix
,
m
,
n
):
# Find all possible patterns.
patterns
=
compute_valid_1d_patterns
(
m
,
n
).
cuda
()
# Find the best m:n pattern (sum of non-masked weights).
mask
=
torch
.
cuda
.
IntTensor
(
matrix
.
shape
).
fill_
(
1
).
view
(
-
1
,
m
)
mat
,
shape
=
reshape_1d
(
matrix
,
m
)
pmax
=
torch
.
argmax
(
torch
.
matmul
(
mat
.
abs
(),
patterns
.
t
()),
dim
=
1
)
mask
[:]
=
patterns
[
pmax
[:]]
mask
=
mask
.
view
(
matrix
.
shape
)
return
mask
def
m4n2_1d
(
mat
,
density
):
return
mn_1d_best
(
mat
,
4
,
2
)
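# For example, given one group of four weights [0.1, -0.9, 0.5, 0.05], mn_1d_best
# keeps the two largest-magnitude entries (-0.9 and 0.5), so the 2:4 mask is
# [0, 1, 1, 0] and the two smallest-magnitude entries are zeroed out.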
""" Comment: Following 2d masking related code (for training) can be removed or marked experimental (78 LOC) """
""" m:n 2d structured greedy """
def
mn_2d_greedy
(
matrix
,
m
,
n
):
# Convert to numpy
mat
=
matrix
.
cpu
().
detach
().
numpy
()
mask
=
np
.
ones
(
mat
.
shape
,
dtype
=
int
)
rowCount
=
int
(
mat
.
shape
[
0
]
/
m
)
*
m
colCount
=
int
(
mat
.
shape
[
1
]
/
m
)
*
m
for
rowStartIdx
in
range
(
0
,
rowCount
,
m
):
rowEndIdx
=
rowStartIdx
+
m
for
colStartIdx
in
range
(
0
,
colCount
,
m
):
colEndIdx
=
colStartIdx
+
m
matrixSub
=
np
.
absolute
(
np
.
squeeze
(
mat
[
rowStartIdx
:
rowEndIdx
,
colStartIdx
:
colEndIdx
]))
maskSub
=
np
.
squeeze
(
mask
[
rowStartIdx
:
rowEndIdx
,
colStartIdx
:
colEndIdx
])
maskSub
.
fill
(
0.0
)
matrixVecView
=
matrixSub
.
reshape
(
-
1
)
maskVecView
=
maskSub
.
reshape
(
-
1
)
linearIdx
=
np
.
argsort
(
matrixVecView
)
matrixIdx
=
[(
int
(
x
/
m
),
x
%
m
)
for
x
in
linearIdx
]
rowCounter
=
collections
.
Counter
()
colCounter
=
collections
.
Counter
()
for
currIdx
in
range
(
len
(
linearIdx
)
-
1
,
-
1
,
-
1
):
currMatrixEntry
=
matrixIdx
[
currIdx
]
if
(
rowCounter
[
currMatrixEntry
[
0
]]
==
n
)
or
(
colCounter
[
currMatrixEntry
[
1
]]
==
n
):
continue
#end if
maskSub
[
currMatrixEntry
[
0
],
currMatrixEntry
[
1
]]
=
1.0
rowCounter
[
currMatrixEntry
[
0
]]
+=
1
colCounter
[
currMatrixEntry
[
1
]]
+=
1
return
torch
.
tensor
(
mask
.
cuda
())
def
m4n2_2d_greedy
(
mat
,
density
):
return
mn_2d_greedy
(
mat
,
4
,
2
)
""" return all possible m:n patterns in a mxn block. """
valid_m4n2_2d_patterns
=
None
def
compute_valid_2d_patterns
(
m
,
n
):
# Early exit if patterns was already created.
global
valid_m4n2_2d_patterns
if
valid_m4n2_2d_patterns
is
not
None
:
return
valid_m4n2_2d_patterns
patterns
=
torch
.
zeros
(
m
)
patterns
[:
n
]
=
1
patterns
=
list
(
set
(
permutations
(
patterns
.
tolist
())))
patterns
=
patterns
+
patterns
patterns
=
torch
.
Tensor
(
list
(
set
(
permutations
(
patterns
,
m
))))
valid
=
((
patterns
.
sum
(
dim
=
1
)
<=
n
).
sum
(
dim
=
1
)
==
m
).
nonzero
().
view
(
-
1
)
valid_patterns
=
torch
.
Tensor
(
valid
.
shape
[
0
],
m
,
m
)
valid_patterns
[:]
=
patterns
[
valid
[:]]
if
m
==
4
and
n
==
2
:
valid_m4n2_2d_patterns
=
valid_patterns
return
valid_patterns
""" m:n 2d structured best """
def
mn_2d_best
(
matrix
,
m
,
n
):
# Find all possible patterns.
patterns
=
compute_valid_2d_patterns
(
m
,
n
).
cuda
()
# Find the best m:n pattern (sum of non-masked weights).
mask
=
torch
.
cuda
.
IntTensor
(
matrix
.
shape
).
fill_
(
1
)
mat
=
reshape_2d
(
matrix
,
m
,
m
).
abs
()
pmax
=
torch
.
argmax
(
torch
.
matmul
(
mat
,
patterns
.
view
(
patterns
.
shape
[
0
],
m
*
m
).
t
()),
dim
=
2
)
# Copy best m:n patterns into mask.
mat
=
mat
.
view
(
mat
.
shape
[
0
]
*
mat
.
shape
[
1
],
-
1
)
pmax
=
pmax
.
view
(
pmax
.
shape
[
0
]
*
pmax
.
shape
[
1
]).
unsqueeze
(
1
).
expand
(
-
1
,
mat
.
shape
[
1
])
patterns
=
patterns
.
view
(
patterns
.
shape
[
0
],
patterns
.
shape
[
1
]
*
patterns
.
shape
[
2
])
mat
=
torch
.
gather
(
patterns
,
0
,
pmax
)
mat
=
reshape_2d_inv
(
mat
.
view
(
matrix
.
shape
[
0
]
//
m
,
matrix
.
shape
[
1
]
//
m
,
m
,
m
))
mask
.
copy_
(
mat
.
type
(
mask
.
type
()))
return
mask
def
m4n2_2d_best
(
mat
,
density
):
return
mn_2d_best
(
mat
,
4
,
2
)
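# NOTE: mn_2d_best above calls reshape_2d and reshape_2d_inv, which are not defined
# in this file (or elsewhere in this commit). A minimal sketch of helpers with the
# shapes mn_2d_best expects, assuming both matrix dimensions are exact multiples of
# the block size, might look like:
def reshape_2d(matrix, m, n):
    # (h, w) -> (h/m, w/n, m*n): each m x n block flattened row-major.
    h, w = matrix.shape
    return matrix.view(h//m, m, w//n, n).permute(0, 2, 1, 3).contiguous().view(h//m, w//n, m*n)

def reshape_2d_inv(blocks):
    # Inverse of reshape_2d on (h/m, w/n, m, n) blocks -> (h, w).
    hb, wb, m, n = blocks.shape
    return blocks.permute(0, 2, 1, 3).contiguous().view(hb*m, wb*n)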
""" returns a sparse mask """
def
create_mask
(
tensor
,
pattern
=
"m4n2_1d"
,
density
=
0.5
):
# Reshape tensor and mask.
shape
=
tensor
.
shape
ttype
=
tensor
.
type
()
t
=
tensor
.
float
().
contiguous
()
# 1d-tensor
if
len
(
shape
)
==
1
:
t
=
t
.
view
(
1
,
shape
[
0
])
func
=
getattr
(
sys
.
modules
[
__name__
],
pattern
,
None
)
mask
=
func
(
t
,
density
)
return
mask
.
view
(
shape
).
type
(
ttype
)
# 2d-tensor (in, out)
elif
len
(
shape
)
==
2
:
t
=
t
.
view
(
shape
[
0
],
shape
[
1
])
func
=
getattr
(
sys
.
modules
[
__name__
],
pattern
,
None
)
mask
=
func
(
t
,
density
)
return
mask
.
view
(
shape
).
type
(
ttype
)
# 3d-tensor (batch, in, out)
elif
len
(
shape
)
==
3
:
t
=
t
.
view
(
shape
[
0
]
*
shape
[
1
],
shape
[
2
])
func
=
getattr
(
sys
.
modules
[
__name__
],
pattern
,
None
)
mask
=
func
(
t
,
density
)
return
mask
.
view
(
shape
).
type
(
ttype
)
# 4d-tensor (in, out, h, w)
elif
len
(
shape
)
==
4
:
"""
# transformers (bmm)
t = t.view(shape[0]*shape[1]*shape[2], shape[3])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
"""
# convs
t
=
t
.
permute
(
2
,
3
,
0
,
1
).
contiguous
().
view
(
shape
[
2
]
*
shape
[
3
]
*
shape
[
0
],
shape
[
1
])
func
=
getattr
(
sys
.
modules
[
__name__
],
pattern
,
None
)
mask
=
func
(
t
,
density
)
mask
=
mask
.
view
(
shape
[
2
],
shape
[
3
],
shape
[
0
],
shape
[
1
]).
permute
(
2
,
3
,
0
,
1
).
contiguous
()
return
mask
.
view
(
shape
).
type
(
ttype
)
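A minimal sketch of using the mask library directly (the weight tensor below is an arbitrary illustration; a CUDA device is required because the pattern search builds torch.cuda tensors):

    import torch
    from apex.contrib.sparsity import create_mask

    w = torch.randn(64, 64).cuda()              # e.g. a Linear weight
    mask = create_mask(w, pattern="m4n2_1d")    # keep 2 of every 4 consecutive weights per row
    print(mask.sum().item() / mask.numel())     # expected density: 0.5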
apex/contrib/sparsity/test/checkpointing_test_part1.py  0 → 100644

from collections import OrderedDict

import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
    od = OrderedDict()
    for i in range(args.num_layers):
        if i == 0:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
        elif i == args.num_layers - 1:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
        else:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
    return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
    predicted_target = model(input_batch)
    loss = ((predicted_target - target_batch) ** 2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    step = step + 1
    #print("Step %d :: loss=%e" % (step, loss.item()))
    return step
def train_loop(args, model, optimizer, step, num_steps):
    for i in range(num_steps):
        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
        step = train_step(args, model, optimizer, input_batch, target_batch, step)
    return step
def main(args):
    #
    # PART1
    #
    torch.manual_seed(args.seed)

    model = build_model(args).cuda()
    one_ll = next(model.children()).weight
    optimizer = FusedAdam(model.parameters())
    ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
    ASP.init_optimizer_for_pruning(optimizer)

    step = 0

    # train for a few steps with dense weights
    print("DENSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_dense_steps)

    # simulate sparsity by inserting zeros into existing dense weights
    ASP.compute_sparse_masks()

    # train for a few steps with sparse weights
    print("SPARSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)

    torch.save({
        'step': step,
        'verbosity': args.verbosity,
        'seed2': args.seed2,
        'pattern': args.pattern,
        'whitelist': args.whitelist,
        'allow_recompute_mask': args.allow_recompute_mask,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        }, args.checkpoint_path)
if __name__ == '__main__':
    class Args:
        verbosity = 3
        seed = 4873
        seed2 = 99875
        pattern = "m4n2_2d_best"
        whitelist = [torch.nn.Linear]
        allow_recompute_mask = True
        batch_size = 32
        input_features = 8
        output_features = 8
        hidden_features = 32
        num_layers = 4
        num_dense_steps = 2000
        num_sparse_steps = 3000
        num_sparse_steps_2 = 1000
        checkpoint_path = "part1.chkp"
    args = Args()

    main(args)
apex/contrib/sparsity/test/checkpointing_test_part2.py  0 → 100644

from collections import OrderedDict

import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
    od = OrderedDict()
    for i in range(args.num_layers):
        if i == 0:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
        elif i == args.num_layers - 1:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
        else:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
    return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
    predicted_target = model(input_batch)
    loss = ((predicted_target - target_batch) ** 2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    step = step + 1
    #print("Step %d :: loss=%e" % (step, loss.item()))
    return step
def train_loop(args, model, optimizer, step, num_steps):
    for i in range(num_steps):
        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
        step = train_step(args, model, optimizer, input_batch, target_batch, step)
    return step
def main(step, args, model_state_dict, optimizer_state_dict):
    #
    # PART2
    #
    model = build_model(args).cuda()
    one_ll = next(model.children()).weight
    optimizer = FusedAdam(model.parameters())
    ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
    ASP.init_optimizer_for_pruning(optimizer)

    torch.manual_seed(args.seed2)
    model.load_state_dict(model_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

    print("Model sparsity is %s" % ("enabled" if ASP.is_sparsity_enabled() else "disabled"))

    # train for a few steps with sparse weights
    print("SPARSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
if __name__ == '__main__':
    checkpoint = torch.load("part1.chkp")

    class Args:
        verbosity = checkpoint['verbosity']
        seed = 4873
        seed2 = checkpoint['seed2']
        pattern = checkpoint['pattern']
        whitelist = checkpoint['whitelist']
        allow_recompute_mask = checkpoint['allow_recompute_mask']
        batch_size = 32
        input_features = 8
        output_features = 8
        hidden_features = 32
        num_layers = 4
        num_dense_steps = 2000
        num_sparse_steps = 3000
        num_sparse_steps_2 = 1000
        checkpoint_path = "part1.chkp"
    args = Args()

    main(checkpoint['step'], args, checkpoint['model_state_dict'], checkpoint['optimizer_state_dict'])
apex/contrib/sparsity/test/checkpointing_test_reference.py  0 → 100644

from collections import OrderedDict

import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP

#
# Reference run for checkpointing test (part1 + part2)
#
def build_model(args):
    od = OrderedDict()
    for i in range(args.num_layers):
        if i == 0:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
        elif i == args.num_layers - 1:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
        else:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
    return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
    predicted_target = model(input_batch)
    loss = ((predicted_target - target_batch) ** 2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    step = step + 1
    #print("Step %d :: loss=%e" % (step, loss.item()))
    return step
def train_loop(args, model, optimizer, step, num_steps):
    for i in range(num_steps):
        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
        step = train_step(args, model, optimizer, input_batch, target_batch, step)
    return step
def main(args):
    #
    # PART1
    #
    torch.manual_seed(args.seed)

    model = build_model(args).cuda()
    one_ll = next(model.children()).weight
    optimizer = FusedAdam(model.parameters())
    ASP.init_model_for_pruning(model, args.pattern, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
    ASP.init_optimizer_for_pruning(optimizer)

    step = 0

    # train for a few steps with dense weights
    print("DENSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_dense_steps)

    # simulate sparsity by inserting zeros into existing dense weights
    ASP.compute_sparse_masks()

    # train for a few steps with sparse weights
    print("SPARSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)

    #
    # PART 2
    #
    torch.manual_seed(args.seed2)

    # train for a few steps with sparse weights
    print("SPARSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
if __name__ == '__main__':
    class Args:
        seed = 4873
        seed2 = 99875
        pattern = "m4n2_2d_best"
        whitelist = [torch.nn.Linear]
        allow_recompute_mask = True
        batch_size = 32
        input_features = 8
        output_features = 8
        hidden_features = 32
        num_layers = 4
        num_dense_steps = 2000
        num_sparse_steps = 3000
        num_sparse_steps_2 = 1000
        checkpoint_path = "part1.chkp"
    args = Args()

    main(args)
apex/contrib/sparsity/test/toy_problem.py  0 → 100644

from collections import OrderedDict

import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
    od = OrderedDict()
    for i in range(args.num_layers):
        if i == 0:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
        elif i == args.num_layers - 1:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
        else:
            od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
            od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
    return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
    predicted_target = model(input_batch)
    loss = ((predicted_target - target_batch) ** 2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    step = step + 1
    #print("Step %d :: loss=%e" % (step, loss.item()))
    return step
def train_loop(args, model, optimizer, step, num_steps):
    for i in range(num_steps):
        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
        step = train_step(args, model, optimizer, input_batch, target_batch, step)
    return step
def main(args):
    model = build_model(args).cuda()
    one_ll = next(model.children()).weight
    optimizer = FusedAdam(model.parameters())
    # only prune linear layers, even though we also support conv1d, conv2d and conv3d
    ASP.init_model_for_pruning(model, "m4n2_1d", whitelist=[torch.nn.Linear], allow_recompute_mask=True)
    ASP.init_optimizer_for_pruning(optimizer)

    step = 0

    # train for a few steps with dense weights
    print("DENSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_dense_steps)

    # simulate sparsity by inserting zeros into existing dense weights
    ASP.compute_sparse_masks()

    # train for a few steps with sparse weights
    print("SPARSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)

    # recompute sparse masks
    ASP.compute_sparse_masks()

    # train for a few steps with sparse weights
    print("SPARSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)

    # turn off sparsity
    print("SPARSE :: ", one_ll)
    ASP.restore_pruned_weights()

    # train for a few steps with dense weights
    print("DENSE :: ", one_ll)
    step = train_loop(args, model, optimizer, step, args.num_dense_steps_2)
if __name__ == '__main__':
    class Args:
        batch_size = 32
        input_features = 16
        output_features = 8
        hidden_features = 40
        num_layers = 4
        num_dense_steps = 2000
        num_sparse_steps = 3000
        num_sparse_steps_2 = 1000
        num_dense_steps_2 = 1500
    args = Args()

    main(args)