Unverified commit 4116ed66 authored by Ashish Farmer, committed by GitHub

Merge pull request #30 from lcskrishna/ifu_07272020

IFU-master 07/27/2020.
parents 8dd19e3b 1c664582
import math
import torch
from torch import nn
from torch.nn import Parameter
@@ -76,7 +78,11 @@ class EncdecMultiheadAttn(nn.Module):
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight_q)
nn.init.xavier_uniform_(self.in_proj_weight_kv)
# in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
# therefore xavier_uniform gain should be set to sqrt(1.5).
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias_q, 0.)
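The gain correction above can be checked numerically. The sketch below is not part of the diff; it only verifies that the xavier_uniform_ bound gain * sqrt(6 / (fan_in + fan_out)) for the fused [2 * hidden, hidden] KV weight with gain sqrt(1.5) equals the bound a plain [hidden, hidden] projection gets with the default gain of 1.

```python
import math

# xavier_uniform_ draws from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)).
# For a 2D weight of shape [out, in], fan_out = out and fan_in = in.
hidden = 1024

# Bound for a plain [hidden, hidden] projection, default gain 1.0.
plain_bound = math.sqrt(6.0 / (hidden + hidden))

# Bound for the fused [2 * hidden, hidden] KV projection with gain sqrt(1.5).
fused_bound = math.sqrt(1.5) * math.sqrt(6.0 / (2 * hidden + hidden))

assert math.isclose(plain_bound, fused_bound)
```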
......
import math
import torch
from torch import nn
from torch.nn import Parameter
@@ -98,7 +100,11 @@ class SelfMultiheadAttn(nn.Module):
nn.init.xavier_uniform_(self.k_weight)
nn.init.xavier_uniform_(self.v_weight)
else:
nn.init.xavier_uniform_(self.in_proj_weight)
# in_proj_weight has shape [3 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
# therefore xavier_uniform gain should be set to sqrt(2).
nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
if self.separate_qkv_params:
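Both gain factors follow the same rule: for a fused projection of shape [k * hidden, hidden], matching the bound of a plain [hidden, hidden] matrix needs gain = sqrt((k * hidden + hidden) / (hidden + hidden)) = sqrt((k + 1) / 2). A minimal check (k is only an illustrative variable, not part of the diff):

```python
import math

def fused_xavier_gain(k):
    # Gain that makes xavier_uniform_ on a [k * hidden, hidden] fused weight
    # match the bound of a plain [hidden, hidden] projection.
    return math.sqrt((k + 1) / 2.0)

assert math.isclose(fused_xavier_gain(2), math.sqrt(1.5))  # fused KV in the previous hunk
assert math.isclose(fused_xavier_gain(3), math.sqrt(2.0))  # fused QKV in this hunk
```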
......
@@ -30,7 +30,7 @@ class ASP:
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False):
allow_recompute_mask=False, custom_layer_dict={}):
"""Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA;
it does not enable use of sparse MMA.
@@ -62,7 +62,9 @@ class ASP:
disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity.
allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
Support for allow_recompute_mask can be removed, it is not part of our recipe -- AKM.
custom_layer_dict Dictionary of additional layer parameters to sparsify, e.g. {CustomLinear: ['weight']}
[Future] Support for allow_recompute_mask can be removed; it is not part of the sparse inference recipe -- AKM.
"""
assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model
@@ -82,6 +84,10 @@ class ASP:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
if custom_layer_dict: # Update the default list with user-supplied (layer type : parameter tensor) entries; the tensor type must be something ASP knows how to prune
sparse_parameter_list.update(custom_layer_dict)
whitelist += list(custom_layer_dict.keys())
for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module_type.__name__
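A hedged usage sketch for the new custom_layer_dict argument. The method name is not visible in this hunk; it is assumed to be ASP.init_model_for_pruning from apex.contrib.sparsity, and CustomLinear is a hypothetical user module whose weight should be pruned. Per the hunk above, the keys of custom_layer_dict are appended to the whitelist automatically.

```python
import torch
from apex.contrib.sparsity import ASP  # assumed import path for this ASP class

class CustomLinear(torch.nn.Module):
    """Hypothetical user-defined layer that ASP does not know about by default."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight)

model = torch.nn.Sequential(CustomLinear(1024, 1024), torch.nn.ReLU())

# Tell ASP which parameter tensors of the custom layer type to sparsify;
# the custom layer type is added to the whitelist on our behalf.
ASP.init_model_for_pruning(model, custom_layer_dict={CustomLinear: ['weight']})
```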
@@ -97,12 +103,12 @@ class ASP:
if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C
print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
continue
if cls.__verbosity >= 3:
print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
mask = torch.ones_like(p).bool()
buffname = name.split(".")[-1] # buffer names cannot contain "."
buffname = p_name.split(".")[-1] # buffer names cannot contain "."
module.register_buffer('__%s_mma_mask' % buffname, mask)
if allow_recompute_mask:
pruned = torch.zeros_like(p).cpu()
@@ -110,6 +116,9 @@ class ASP:
else:
pruned = None
cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
else:
if cls.__verbosity >= 3:
print("[ASP] Not sparsifying %s::%s of size=%s and type=%s" % (module_name, p_name, str(p.size()), str(p.dtype)))
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module)
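The auto-skip rule in this hunk (fp16 weights whose first dimension is not a multiple of 8 or whose second dimension is not a multiple of 16 cannot be tiled for sparse MMA) can be mirrored in isolation. A small sketch with illustrative names, not part of ASP itself:

```python
import torch

def is_sparsifiable(p):
    # Mirror of the check above: fp16 weights need dim0 % 8 == 0 and dim1 % 16 == 0.
    if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0):
        return False
    return True

assert is_sparsifiable(torch.zeros(1024, 1024, dtype=torch.float16))
assert not is_sparsifiable(torch.zeros(100, 1024, dtype=torch.float16))  # 100 % 8 != 0
```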
@@ -131,7 +140,8 @@ class ASP:
# prune gradients before step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
p.grad.mul_(mask)
if p.grad is not None: #thx pjudd
p.grad.mul_(mask)
# call original optimizer step method
rval = opt_self.__step(*args, **kwargs)
# prune parameters after step method
......
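For context on the guarded gradient masking above, here is a minimal standalone sketch (not the ASP implementation) of patching an optimizer's step so masked entries of each tracked gradient are zeroed before the original step runs; sparse_parameters is a simplified list of (param, mask) pairs rather than the tuples used above.

```python
import types
import torch

def wrap_step_with_grad_masking(optimizer, sparse_parameters):
    original_step = optimizer.step  # keep the unpatched bound method

    def step(self, *args, **kwargs):
        with torch.no_grad():
            for p, mask in sparse_parameters:
                if p.grad is not None:  # some parameters may not have received gradients
                    p.grad.mul_(mask)
        return original_step(*args, **kwargs)

    optimizer.step = types.MethodType(step, optimizer)
```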