Unverified Commit 459de22d authored by Thor Johnsen's avatar Thor Johnsen Committed by GitHub
Browse files

Merge pull request #918 from a-maci/ASP_sparse_param_dict_update

Asp sparse param dict update
parents 0ac5dd62 b3c16411
...@@ -30,7 +30,7 @@ class ASP: ...@@ -30,7 +30,7 @@ class ASP:
verbosity=3, verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d], whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
allowed_layer_names=None, disallowed_layer_names=[], allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False): allow_recompute_mask=False, custom_layer_dict={}):
"""Call this method to modify your model to take advantage of sparse matrix multiplication. """Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA, Note that this call alone only augments the model with additional buffers needed for sparse MMA,
it does not enable use of sparse MMA. it does not enable use of sparse MMA.
...@@ -62,7 +62,9 @@ class ASP: ...@@ -62,7 +62,9 @@ class ASP:
disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity. disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity.
allow_recompute_mask If True, stores pruned values so that dense weights can be restored. allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage. Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
Support for allow_recompute_mask can be removed, it is not part of our recipe -- AKM. custom_layer_dict Dictionary of additional layer paremeters to sparsify. e.g. {CustomLinear: ['weight']}
[Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe -- AKM.
""" """
assert (cls.__model is None), "ASP has been initialized already." assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model cls.__model = model
...@@ -82,6 +84,10 @@ class ASP: ...@@ -82,6 +84,10 @@ class ASP:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']} sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
else: else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']} sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
if custom_layer_dict: # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune
sparse_parameter_list.update(custom_layer_dict)
whitelist += list(custom_layer_dict.keys())
for module_type in whitelist: for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module.dtype() assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module.dtype()
...@@ -110,6 +116,9 @@ class ASP: ...@@ -110,6 +116,9 @@ class ASP:
else: else:
pruned = None pruned = None
cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned)) cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
else:
if cls.__verbosity >= 3:
print("[ASP] Not sparsifying %s::%s of size=%s and type=%s" % (module_name, p_name, str(p.size()), str(p.dtype)))
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names): for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module) add_sparse_attributes(name, sparse_module)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment