Commit 2a5b726a authored by Kirthi Sivamani

asp files

parent 097238f8
from .sparse_masklib import create_mask
from .asp import ASP
import types
import torch
from .sparse_masklib import create_mask
torchvision_imported=True
try:
import torchvision
except ImportError:
print("[ASP][Warning] torchvision cannot be imported, may infuence functionality of MaskRCNN/KeypointRCNN network from torchvision.")
torchvision_imported=False
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
eligible_modules_list = []
for name, mod in model.named_modules():
if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:
if allowed_layer_names is not None and name not in allowed_layer_names:
continue
eligible_modules_list.append((name, mod))
return eligible_modules_list
class ASP:
__model = None
__verbosity = 0
__optimizer = None
__sparse_parameters = []
__calculate_mask = None
@classmethod
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False):
"""Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA,
it does not enable use of sparse MMA.
If you are starting with a fresh model:
model = ...
ASP.init_model_for_pruning(model, mask_calculator, ...)
if (training) ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks() // sparsity is off by default, call when you want to enable it.
If you are starting from a checkpoint:
model = ...
ASP.init_model_for_pruning(model, mask_calculator, ...)
torch.load(...)
if (training) ASP.init_optimizer_for_pruning(optimizer)
Arguments:
model The model
mask_calculator Either a callable that computes a mask given a tensor, OR a pattern string for the sparse mask lib.
verbosity Integer controlling verbosity level.
0 -> Only errors.
1 -> Errors and warnings.
2 -> Errors, warnings and info.
3 -> Errors, warnings, info and debug.
whitelist Module types approved for sparsity.
allowed_layer_names If not None, only layer names that appear in this list are considered for sparsity.
disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity.
allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
Support for allow_recompute_mask can be removed, it is not part of our recipe -- AKM.
"""
assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model
cls.__verbosity = verbosity
if isinstance(mask_calculator, str):
def create_mask_from_pattern(param):
return create_mask(param, mask_calculator).bool()
cls.__calculate_mask = create_mask_from_pattern
else:
cls.__calculate_mask = mask_calculator #user defined function
# function to extract variables that will be sparsified.
# idea is that you will add one of these functions for each module type that can be sparsified.
if torchvision_imported:
print("[ASP] torchvision is imported, can work smoothly with the MaskRCNN/KeypointRCNN from torchvision.")
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % str(module_type)
# find all sparse modules, extract sparse parameters and decorate
def add_sparse_attributes(module_name, module):
sparse_parameters = sparse_parameter_list[type(module)]
for p_name, p in module.named_parameters():
if p_name in sparse_parameters and p.requires_grad:
# check for NVIDIA's TC compatibility: we check along the horizontal direction
if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #User defines FP32 and APEX internally uses FP16 math
print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
continue
if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C
print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
continue
if cls.__verbosity >= 3:
print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
mask = torch.ones_like(p).bool()
buffname = module_name.split(".")[-1] # buffer names cannot contain "."
module.register_buffer('__%s_mma_mask' % buffname, mask)
if allow_recompute_mask:
pruned = torch.zeros_like(p).cpu()
module.register_buffer('__%s_mma_pruned_p' % buffname, pruned)
else:
pruned = None
cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module)
@classmethod
def init_optimizer_for_pruning(cls, optimizer):
"""Call this method to monkey patch optimizer step function so that masks can be applied to
gradients and weights during training.
You must call init_model_for_pruning(...) before calling init_optimizer_for_pruning(...)
"""
assert (cls.__optimizer is None), "ASP has already been initialized with an optimizer."
assert (cls.__calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."
# store pointer to original optimizer step method
cls.__optimizer = optimizer
cls.__optimizer.__step = optimizer.step
def __step(opt_self, *args, **kwargs):
# prune gradients before step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
p.grad.mul_(mask)
# call original optimizer step method
rval = opt_self.__step(*args, **kwargs)
# prune parameters after step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
p.mul_(mask)
return rval
cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)
@classmethod
def compute_sparse_masks(cls):
"""Call this method to enable sparsity.
If init(...) was called with allow_recompute_mask=False, the pruned buffers are None, so masks cannot be recomputed once sparsity has already been enabled.
"""
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel(): # when recalculating masks
# restore dense parameter if allow_recompute_mask is enabled
assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
p.add_(pruned.cuda())
mask.set_(cls.__calculate_mask(p))
if pruned is not None: # stow away pruned weights to cpu
pruned.set_((p * (~mask)).cpu())
p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
if cls.__verbosity >= 2:
print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))
@classmethod
def restore_pruned_weights(cls):
"""Call this method to disable sparsity and restore all weights.
This will only work if init(...) was called with allow_recompute_mask=True.
"""
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel():
assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
p.add_(pruned.cuda())
mask.fill_(1)
pruned.zero_()
if cls.__verbosity >= 2:
print("[ASP] Disabled sparsity for %s::%s (dense weights restored)" % (module_name, p_name))
@classmethod
def is_sparsity_enabled(cls):
"""Call this method to determine if sparsity is enabled in the model.
The typical use case is right after checkpoint has been loaded.
"""
total,sp100,sp50 = 0,0,0
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
total += 1
mask_sum = mask.sum()
mask_numel = mask.numel()
if mask_sum == mask_numel:
sp100 += 1
elif mask_sum*2 == mask_numel:
sp50 += 1
assert (total == sp100 or total == sp50), "Inconsistent model sparsity"
if total == sp100:
return False
elif total == sp50:
return True
@classmethod
def prune_trained_model(cls, model, optimizer):
# add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)
cls.init_optimizer_for_pruning(optimizer)
cls.compute_sparse_masks()
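# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the ASP API above): the "prune a
# trained dense model, then fine-tune with masked weights" recipe described in
# the docstrings, written out end to end. The toy model, optimizer, and data
# below are placeholders chosen for illustration; a CUDA device is assumed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # A small model whose Linear weights satisfy the divisibility checks above.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64),
                                torch.nn.ReLU(),
                                torch.nn.Linear(64, 64)).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Add mask buffers, monkey patch optimizer.step, and compute 2:4 masks.
    ASP.prune_trained_model(model, optimizer)

    # One fine-tuning step: gradients and weights are re-masked inside step().
    x, y = torch.randn(32, 64).cuda(), torch.randn(32, 64).cuda()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()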
import sys
import torch
import numpy as np
import collections
from itertools import permutations
""" compute density (helper fn to compute % NNZs in a tensor)"""
def fill(x):
return float(x.nonzero().size(0))/torch.numel(x)
""" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) """
def reshape_1d(matrix, m):
# If not a nice multiple of m, fill with zeroes.
if matrix.shape[1] % m > 0:
mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m-matrix.shape[1]%m)).fill_(0)
mat[:, :matrix.shape[1]] = matrix
shape = mat.shape
return mat.view(-1,m),shape
else:
return matrix.view(-1,m), matrix.shape
""" return all possible m:n patterns in a 1d vector. """
valid_m4n2_1d_patterns = None
def compute_valid_1d_patterns(m,n):
# Early exit if patterns was already created.
global valid_m4n2_1d_patterns
if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns
patterns = torch.zeros(m)
patterns[:n] = 1
valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns
return valid_patterns
""" m:n 1d structured best """
def mn_1d_best(matrix, m, n):
# Find all possible patterns.
patterns = compute_valid_1d_patterns(m,n).cuda()
# Find the best m:n pattern (sum of non-masked weights).
mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m)
mat,shape = reshape_1d(matrix,m)
pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1)
mask[:] = patterns[pmax[:]]
mask = mask.view(matrix.shape)
return mask
def m4n2_1d(mat, density):
return mn_1d_best(mat, 4, 2)
""" Comment: Following 2d masking related code (for training) can be removed or marked experimental (78 LOC) """
""" m:n 2d structured greedy """
def mn_2d_greedy(matrix, m, n):
# Convert to numpy
mat = matrix.cpu().detach().numpy()
mask = np.ones(mat.shape, dtype=int)
rowCount = int(mat.shape[0]/m) * m
colCount = int(mat.shape[1]/m) * m
for rowStartIdx in range(0, rowCount, m):
rowEndIdx = rowStartIdx + m
for colStartIdx in range(0, colCount, m):
colEndIdx = colStartIdx + m
matrixSub = np.absolute(np.squeeze(mat[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx]))
maskSub = np.squeeze(mask[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx])
maskSub.fill(0.0)
matrixVecView = matrixSub.reshape(-1)
maskVecView = maskSub.reshape(-1)
linearIdx = np.argsort(matrixVecView)
matrixIdx = [(int(x/m), x % m) for x in linearIdx]
rowCounter = collections.Counter()
colCounter = collections.Counter()
for currIdx in range(len(linearIdx) - 1, -1, -1):
currMatrixEntry = matrixIdx[currIdx]
if (rowCounter[currMatrixEntry[0]] == n) or (colCounter[currMatrixEntry[1]] == n):
continue
#end if
maskSub[currMatrixEntry[0], currMatrixEntry[1]] = 1.0
rowCounter[currMatrixEntry[0]] += 1
colCounter[currMatrixEntry[1]] += 1
return torch.tensor(mask).cuda()
def m4n2_2d_greedy(mat, density):
return mn_2d_greedy(mat, 4, 2)
""" return all possible m:n patterns in a mxn block. """
valid_m4n2_2d_patterns = None
def compute_valid_2d_patterns(m,n):
# Early exit if patterns was already created.
global valid_m4n2_2d_patterns
if valid_m4n2_2d_patterns is not None: return valid_m4n2_2d_patterns
patterns = torch.zeros(m)
patterns[:n] = 1
patterns = list(set(permutations(patterns.tolist())))
patterns = patterns + patterns
patterns = torch.Tensor(list(set(permutations(patterns,m))))
valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1)
valid_patterns = torch.Tensor(valid.shape[0],m,m)
valid_patterns[:] = patterns[valid[:]]
if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns
return valid_patterns
""" m:n 2d structured best """
def mn_2d_best(matrix, m, n):
# Find all possible patterns.
patterns = compute_valid_2d_patterns(m,n).cuda()
# Find the best m:n pattern (sum of non-masked weights).
mask = torch.cuda.IntTensor(matrix.shape).fill_(1)
mat = reshape_2d(matrix,m,m).abs()
pmax = torch.argmax(torch.matmul(mat,patterns.view(patterns.shape[0],m*m).t()), dim=2)
# Copy best m:n patterns into mask.
mat = mat.view(mat.shape[0]*mat.shape[1],-1)
pmax = pmax.view(pmax.shape[0]*pmax.shape[1]).unsqueeze(1).expand(-1,mat.shape[1])
patterns = patterns.view(patterns.shape[0],patterns.shape[1]*patterns.shape[2])
mat = torch.gather(patterns,0,pmax)
mat = reshape_2d_inv(mat.view(matrix.shape[0]//m,matrix.shape[1]//m,m,m))
mask.copy_(mat.type(mask.type()))
return mask
def m4n2_2d_best(mat, density):
return mn_2d_best(mat, 4, 2)
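""" NOTE: mn_2d_best above calls reshape_2d and reshape_2d_inv, which are not
    defined in this file. The helpers below are a minimal sketch consistent
    with how they are used (tiling an (h, w) matrix into m x m blocks and
    back); they are an assumption, not the original implementation. """
def reshape_2d(matrix, m, n):
    # Tile an (h, w) matrix into (h/m, w/n, m*n) blocks, row-major inside each block.
    # Assumes h and w are multiples of m and n, as mn_2d_best does.
    h, w = matrix.shape
    return matrix.view(h//m, m, w//n, n).permute(0, 2, 1, 3).reshape(h//m, w//n, m*n)
def reshape_2d_inv(tiles):
    # Inverse of the tiling above: (h/m, w/n, m, n) blocks back to an (h, w) matrix.
    hb, wb, m, n = tiles.shape
    return tiles.permute(0, 2, 1, 3).reshape(hb*m, wb*n)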
""" returns a sparse mask """
def create_mask(tensor, pattern="m4n2_1d", density=0.5):
# Reshape tensor and mask.
shape = tensor.shape
ttype = tensor.type()
t = tensor.float().contiguous()
# 1d-tensor
if len(shape) == 1:
t = t.view(1, shape[0])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
# 2d-tensor (in, out)
elif len(shape) == 2:
t = t.view(shape[0], shape[1])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
# 3d-tensor (batch, in, out)
elif len(shape) == 3:
t = t.view(shape[0]*shape[1], shape[2])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
# 4d-tensor (in, out, h, w)
elif len(shape) == 4:
"""
# transformers (bmm)
t = t.view(shape[0]*shape[1]*shape[2], shape[3])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
return mask.view(shape).type(ttype)
"""
# convs
t = t.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1])
func = getattr(sys.modules[__name__], pattern, None)
mask = func(t, density)
mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous()
return mask.view(shape).type(ttype)
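# ---------------------------------------------------------------------------
# Illustrative self-check (an assumption-labelled sketch, not part of the
# library): exercises compute_valid_1d_patterns and create_mask on small
# tensors. Requires a CUDA device, since the mask routines above allocate
# torch.cuda tensors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 4:2 has C(4,2) = 6 valid placements of two non-zeros per group of four.
    patterns = compute_valid_1d_patterns(4, 2)
    assert patterns.shape == (6, 4) and bool((patterns.sum(dim=1) == 2).all())

    # A 2:4 mask of a random matrix keeps exactly half of the entries.
    weight = torch.randn(16, 32).cuda()
    mask = create_mask(weight, pattern="m4n2_1d")
    assert fill(mask) == 0.5
    print("sparse_masklib self-check passed")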
from collections import OrderedDict
import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
od = OrderedDict()
for i in range(args.num_layers):
if i == 0:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
elif i == args.num_layers-1:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
else:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
predicted_target = model(input_batch)
loss = ((predicted_target-target_batch)**2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
step = step + 1
#print("Step %d :: loss=%e" % (step, loss.item()))
return step
def train_loop(args, model, optimizer, step, num_steps):
for i in range(num_steps):
input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
step = train_step(args, model, optimizer, input_batch, target_batch, step)
return step
def main(args):
#
# PART1
#
torch.manual_seed(args.seed)
model = build_model(args).cuda()
one_ll = next(model.children()).weight
optimizer = FusedAdam(model.parameters())
ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
ASP.init_optimizer_for_pruning(optimizer)
step = 0
# train for a few steps with dense weights
print("DENSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights
ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps)
torch.save({
'step': step,
'verbosity': args.verbosity,
'seed2': args.seed2,
'pattern': args.pattern,
'whitelist': args.whitelist,
'allow_recompute_mask': args.allow_recompute_mask,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, args.checkpoint_path)
if __name__ == '__main__':
class Args:
verbosity=3
seed = 4873
seed2 = 99875
pattern = "m4n2_2d_best"
whitelist = [torch.nn.Linear]
allow_recompute_mask = True
batch_size = 32
input_features = 8
output_features = 8
hidden_features = 32
num_layers = 4
num_dense_steps = 2000
num_sparse_steps = 3000
num_sparse_steps_2 = 1000
checkpoint_path = "part1.chkp"
args = Args()
main(args)
from collections import OrderedDict
import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
od = OrderedDict()
for i in range(args.num_layers):
if i == 0:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
elif i == args.num_layers-1:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
else:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
predicted_target = model(input_batch)
loss = ((predicted_target-target_batch)**2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
step = step + 1
#print("Step %d :: loss=%e" % (step, loss.item()))
return step
def train_loop(args, model, optimizer, step, num_steps):
for i in range(num_steps):
input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
step = train_step(args, model, optimizer, input_batch, target_batch, step)
return step
def main(step, args, model_state_dict, optimizer_state_dict):
#
# PART2
#
model = build_model(args).cuda()
one_ll = next(model.children()).weight
optimizer = FusedAdam(model.parameters())
ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
ASP.init_optimizer_for_pruning(optimizer)
torch.manual_seed(args.seed2)
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
print("Model sparsity is %s" % ("enabled" if ASP.sparsity_is_enabled() else "disabled"))
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
if __name__ == '__main__':
checkpoint = torch.load("part1.chkp")
class Args:
verbosity = checkpoint['verbosity']
seed = 4873
seed2 = checkpoint['seed2']
pattern = checkpoint['pattern']
whitelist = checkpoint['whitelist']
allow_recompute_mask = checkpoint['allow_recompute_mask']
batch_size = 32
input_features = 8
output_features = 8
hidden_features = 32
num_layers = 4
num_dense_steps = 2000
num_sparse_steps = 3000
num_sparse_steps_2 = 1000
checkpoint_path = "part1.chkp"
args = Args()
main(checkpoint['step'], args, checkpoint['model_state_dict'], checkpoint['optimizer_state_dict'])
from collections import OrderedDict
import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
#
# Reference run for checkpointing test (part1 + part2)
#
def build_model(args):
od = OrderedDict()
for i in range(args.num_layers):
if i == 0:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
elif i == args.num_layers-1:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
else:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
predicted_target = model(input_batch)
loss = ((predicted_target-target_batch)**2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
step = step + 1
#print("Step %d :: loss=%e" % (step, loss.item()))
return step
def train_loop(args, model, optimizer, step, num_steps):
for i in range(num_steps):
input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
step = train_step(args, model, optimizer, input_batch, target_batch, step)
return step
def main(args):
#
# PART1
#
torch.manual_seed(args.seed)
model = build_model(args).cuda()
one_ll = next(model.children()).weight
optimizer = FusedAdam(model.parameters())
ASP.init_model_for_pruning(model, args.pattern, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
ASP.init_optimizer_for_pruning(optimizer)
step = 0
# train for a few steps with dense weights
print("DENSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights
ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps)
#
# PART 2
#
torch.manual_seed(args.seed2)
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
if __name__ == '__main__':
class Args:
seed = 4873
seed2 = 99875
pattern = "m4n2_2d_best"
whitelist = [torch.nn.Linear]
allow_recompute_mask = True
batch_size = 32
input_features = 8
output_features = 8
hidden_features = 32
num_layers = 4
num_dense_steps = 2000
num_sparse_steps = 3000
num_sparse_steps_2 = 1000
checkpoint_path = "part1.chkp"
args = Args()
main(args)
from collections import OrderedDict
import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP
def build_model(args):
od = OrderedDict()
for i in range(args.num_layers):
if i == 0:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
elif i == args.num_layers-1:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features])
else:
od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features)
od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features])
return torch.nn.Sequential(od)
def train_step(args, model, optimizer, input_batch, target_batch, step):
predicted_target = model(input_batch)
loss = ((predicted_target-target_batch)**2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
step = step + 1
#print("Step %d :: loss=%e" % (step, loss.item()))
return step
def train_loop(args, model, optimizer, step, num_steps):
for i in range(num_steps):
input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
step = train_step(args, model, optimizer, input_batch, target_batch, step)
return step
def main(args):
model = build_model(args).cuda()
one_ll = next(model.children()).weight
optimizer = FusedAdam(model.parameters())
# only prune linear layers, even though we also support conv1d, conv2d and conv3d
ASP.init_model_for_pruning(model, "m4n2_1d", whitelist=[torch.nn.Linear], allow_recompute_mask=True)
ASP.init_optimizer_for_pruning(optimizer)
step = 0
# train for a few steps with dense weights
print("DENSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights
ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps)
# recompute sparse masks
ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
# turn off sparsity
print("SPARSE :: ",one_ll)
ASP.restore_pruned_weights()
# train for a few steps with dense weights
print("DENSE :: ",one_ll)
step = train_loop(args, model, optimizer, step, args.num_dense_steps_2)
if __name__ == '__main__':
class Args:
batch_size = 32
input_features = 16
output_features = 8
hidden_features = 40
num_layers = 4
num_dense_steps = 2000
num_sparse_steps = 3000
num_sparse_steps_2 = 1000
num_dense_steps_2 = 1500
args = Args()
main(args)