"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "ac63c4545e8f55fb7c79aa0e1668f204dc0f27e4"
Unverified Commit 89edb819 authored by ChongyuNVIDIA's avatar ChongyuNVIDIA Committed by GitHub
Browse files

Add the permutation related support as the extension for asp lib. (#1194)

* Add the permutation related support as the extension for asp lib.

* [Fix] Track the permutation sequence for progressive channel swap strategy.

* Fix the corner case that one layer is not sparse, but need to apply permutation due to its siblings.

* Fix the deprecated functions in ASP unit tests.

* Fix the sparsity info typo in ASP lib.

* [Enhancement] Set the identical random seed for all GPUs to make sure the same results generated in permutation search.

* Update the README.md with identical random seed setting and NeurIPS info.

* Integrate the Pybind11 enhancement of permutation search into ASP lib.
parent 79c01877
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python. This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
## Importing ASP ## Importing ASP
``` ```
from apex.contrib.sparsity import ASP from apex.contrib.sparsity import ASP
``` ```
...@@ -10,11 +11,13 @@ from apex.contrib.sparsity import ASP ...@@ -10,11 +11,13 @@ from apex.contrib.sparsity import ASP
## Initializing ASP ## Initializing ASP
Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference: Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:
``` ```
ASP.prune_trained_model(model, optimizer) ASP.prune_trained_model(model, optimizer)
``` ```
In the context of a typical PyTorch training loop, it might look like this: In the context of a typical PyTorch training loop, it might look like this:
``` ```
ASP.prune_trained_model(model, optimizer) ASP.prune_trained_model(model, optimizer)
...@@ -27,6 +30,7 @@ for epoch in range(epochs): ...@@ -27,6 +30,7 @@ for epoch in range(epochs):
torch.save(...) torch.save(...)
``` ```
The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step.
## Generate a Sparse Network ## Generate a Sparse Network
...@@ -42,7 +46,6 @@ The following approach serves as a guiding example on how to generate a pruned m ...@@ -42,7 +46,6 @@ The following approach serves as a guiding example on how to generate a pruned m
In code, below is a sketch on how to use ASP for this approach (steps 1 and 2 above). In code, below is a sketch on how to use ASP for this approach (steps 1 and 2 above).
``` ```
model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint) model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ... # compare ground truth with model predition; use the same criterion as used to generate the dense trained model criterion = ... # compare ground truth with model predition; use the same criterion as used to generate the dense trained model
optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model
...@@ -72,7 +75,60 @@ ASP.compute_sparse_masks() ...@@ -72,7 +75,60 @@ ASP.compute_sparse_masks()
A more thorough example can be found in `./test/toy_problem.py`. A more thorough example can be found in `./test/toy_problem.py`.
## Advanced Usage: Channel Permutation
We introduce channel permutations as an advanced method to maximize the accuracy of structured sparse networks. By permuting weight matrices along their channel dimension and adjusting the surrounding layers appropriately, we demonstrate accuracy recovery for even small, parameter-efficient networks, without affecting inference run-time.
The final accuracy has a strong relationship with the quality of permutations. We provide the default algorithms to search for high-quality permutations. The permutation search process can be accelerated by the Apex CUDA extension: `apex.contrib.sparsity.permutation_search_kernels`
If you want to use the GPU to accelerate the permutation search process, we recommend installing Apex with permutation search CUDA extension via
```
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--permutation_search" ./
```
If you want to disable the permutation search process, please pass the `allow_permutation=False` to `init_model_for_pruning` function. For example:
```
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False, allow_permutation=False)
```
Please notice, when using multi-GPUs we should set the identical random seed for all GPUs to make sure the same results generated in permutation search. The library has implemented the `set_identical_seed` function in `permutation_lib.py`, and be called in ASP library. We still suggest the users to set the identical random seed when using multi-GPUs in their code, the example code is as follows:
```
import torch
import numpy
import random
torch.manual_seed(identical_seed)
torch.cuda.manual_seed_all(identical_seed)
numpy.random.seed(identical_seed)
random.seed(identical_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```
## Reference Papers
More details about sparsity support on the NVIDIA Ampere GPU with Sparse Tensor Cores can refer to our [white paper](https://arxiv.org/abs/2104.08378).
```
@article{mishra2021accelerating,
title={Accelerating sparse deep neural networks},
author={Mishra, Asit and Latorre, Jorge Albericio and Pool, Jeff and Stosic, Darko and Stosic, Dusan and Venkatesh, Ganesh and Yu, Chong and Micikevicius, Paulius},
journal={arXiv preprint arXiv:2104.08378},
year={2021}
}
```
The details about sparsity with permutation can refer to our [paper](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html) published in *Thirty-fifth Conference on Neural Information Processing Systems* (**NeurIPS 2021**):
```
@article{pool2021channel,
title={Channel Permutations for N: M Sparsity},
author={Pool, Jeff and Yu, Chong},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
```
import types import types
import torch import torch
from .sparse_masklib import create_mask from .sparse_masklib import create_mask
from .permutation_lib import Permutation
torchvision_imported=True torchvision_imported=True
try: try:
...@@ -9,6 +10,11 @@ except ImportError: ...@@ -9,6 +10,11 @@ except ImportError:
print("[ASP][Warning] torchvision cannot be imported.") print("[ASP][Warning] torchvision cannot be imported.")
torchvision_imported=False torchvision_imported=False
import json
import os
import string
import time
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names): def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
eligible_modules_list = [] eligible_modules_list = []
for name, mod in model.named_modules(): for name, mod in model.named_modules():
...@@ -18,19 +24,25 @@ def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallow ...@@ -18,19 +24,25 @@ def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallow
eligible_modules_list.append((name, mod)) eligible_modules_list.append((name, mod))
return eligible_modules_list return eligible_modules_list
class ASP: class ASP:
__model = None __model = None
__verbosity = 0 __verbosity = 0
__optimizer = None __optimizer = None
__sparse_parameters = [] __sparse_parameters = []
__calculate_mask = None __calculate_mask = None
__allow_permutation = True
__all_parameters = []
__save_permutation_graph = False
__permutation_output_dir = ''
@classmethod @classmethod
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3, verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d], whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
allowed_layer_names=None, disallowed_layer_names=[], allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False, custom_layer_dict={}): allow_recompute_mask=False, custom_layer_dict={},
allow_permutation=True):
"""Call this method to modify your model to take advantage of sparse matrix multiplication. """Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA, Note that this call alone only augments the model with additional buffers needed for sparse MMA,
it does not enable use of sparse MMA. it does not enable use of sparse MMA.
...@@ -63,12 +75,14 @@ class ASP: ...@@ -63,12 +75,14 @@ class ASP:
allow_recompute_mask If True, stores pruned values so that dense weights can be restored. allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage. Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
custom_layer_dict Dictionary of additional layer paremeters to sparsify. e.g. {CustomLinear: ['weight']} custom_layer_dict Dictionary of additional layer paremeters to sparsify. e.g. {CustomLinear: ['weight']}
allow_permutation If True, allow the input channel permutation to ease the influence of weight pruning.
[Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe -- AKM. [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe.
""" """
assert (cls.__model is None), "ASP has been initialized already." assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model cls.__model = model
cls.__verbosity = verbosity cls.__verbosity = verbosity
cls.__allow_permutation = allow_permutation
if isinstance(mask_calculator, str): if isinstance(mask_calculator, str):
def create_mask_from_pattern(param): def create_mask_from_pattern(param):
...@@ -91,6 +105,28 @@ class ASP: ...@@ -91,6 +105,28 @@ class ASP:
for module_type in whitelist: for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module.dtype() assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module.dtype()
if allow_permutation: # find all named modules, extract parameters and decorate, used for offline permutation in K dim
for module_name, module in model.named_modules():
module_type_str = str(type(module)).split("\'")[1]
if module_type_str == 'torch.nn.modules.container.Sequential' or module_type_str.startswith('torchvision.models'):
# filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG'
continue
for p_name, p in module.named_parameters():
cls.__all_parameters.append((module_name, module, p_name, p))
if module_type_str == 'torch.nn.modules.batchnorm.BatchNorm2d':
# need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters
module_mean_name = module_name + '.running_mean'
module_var_name = module_name + '.running_var'
for param_key in model.state_dict():
if module_mean_name == param_key or module_var_name == param_key:
cls.__all_parameters.append((module_name, module, param_key.split(".")[-1], model.state_dict()[param_key]))
# add the __permutation_output_dir field to save the intermediate results for permutation
cls.__permutation_output_dir = '.'
# Set the corresponding params from ASP class to the Permutation class
Permutation.set_permutation_params_from_asp(cls.__model, cls.__sparse_parameters, cls.__all_parameters)
# Set the identical random seed for all GPUs to make sure the same results generated in permutation search
Permutation.set_identical_seed()
# find all sparse modules, extract sparse parameters and decorate # find all sparse modules, extract sparse parameters and decorate
def add_sparse_attributes(module_name, module): def add_sparse_attributes(module_name, module):
sparse_parameters = sparse_parameter_list[type(module)] sparse_parameters = sparse_parameter_list[type(module)]
...@@ -123,6 +159,19 @@ class ASP: ...@@ -123,6 +159,19 @@ class ASP:
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names): for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module) add_sparse_attributes(name, sparse_module)
@classmethod
def already_init_asp_model(cls):
"""Call this method to check whether ASP has been initialized already.
"""
if cls.__model is None:
if cls.__verbosity >= 3:
print("[ASP] ASP has not been initialized.")
return False
else:
if cls.__verbosity >= 3:
print("[ASP] ASP has been initialized already.")
return True
@classmethod @classmethod
def init_optimizer_for_pruning(cls, optimizer): def init_optimizer_for_pruning(cls, optimizer):
"""Call this method to monkey patch optimizer step function so that masks can be applied to """Call this method to monkey patch optimizer step function so that masks can be applied to
...@@ -157,6 +206,38 @@ class ASP: ...@@ -157,6 +206,38 @@ class ASP:
If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None. If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
""" """
with torch.no_grad(): with torch.no_grad():
if cls.__allow_permutation:
# Step 1: use the Torch.FX library to build the graph
# Step 2: permutation search with the customized kernel
# Notice: need to use the single GPU to build the Torch.FX graph
# The simplest without user intervention:
# A. try to import with the distributed mode of the original model
# B. if meet the error, import with the none-distributed mode of the original model
start_time_build_offline_permutation_graph = time.perf_counter()
try:
offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model.module, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
print("\n[compute_sparse_masks] build offline permutation graph on distributed model.")
except AttributeError:
offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
print("\n[compute_sparse_masks] build offline permutation graph on none-distributed model.")
duration_build_offline_permutation_graph = time.perf_counter() - start_time_build_offline_permutation_graph
print("[compute_sparse_masks] Take {:.4f} seconds to finish build_offline_permutation_graph function.".format(duration_build_offline_permutation_graph))
# Step 3: off-line permutation to avoid the runtime overhead in deployment
if success_in_build_offline_permutation_graph:
start_time_apply_offline_permutation = time.perf_counter()
try:
Permutation.apply_offline_permutation(cls.__model.module, fx_graph=offline_permutation_fx_graph)
print("\n[compute_sparse_masks] apply offline permutation on distributed model.")
except AttributeError:
Permutation.apply_offline_permutation(cls.__model, fx_graph=offline_permutation_fx_graph)
print("\n[compute_sparse_masks] apply offline permutation on none-distributed model.")
duration_apply_offline_permutation = time.perf_counter() - start_time_apply_offline_permutation
print("[compute_sparse_masks] Take {:.4f} seconds to finish apply_offline_permutation function.\n".format(duration_apply_offline_permutation))
else:
print("[compute_sparse_masks] skip applying offline permutation because there is no valid offline_permutation_fx_graph.")
# Finally, permutation search and off-line permutation is done, give the model back to ASP to generate the normal structured sparse mask
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel(): # when recalculating masks if mask.sum() < mask.numel(): # when recalculating masks
# restore dense parameter if allow_recompute_mask is enabled # restore dense parameter if allow_recompute_mask is enabled
...@@ -170,7 +251,7 @@ class ASP: ...@@ -170,7 +251,7 @@ class ASP:
p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
if cls.__verbosity >= 2: if cls.__verbosity >= 2:
print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype))) print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0-100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))
@classmethod @classmethod
def restore_pruned_weights(cls): def restore_pruned_weights(cls):
...@@ -215,3 +296,17 @@ class ASP: ...@@ -215,3 +296,17 @@ class ASP:
cls.init_optimizer_for_pruning(optimizer) cls.init_optimizer_for_pruning(optimizer)
cls.compute_sparse_masks() cls.compute_sparse_masks()
@classmethod
def set_permutation_saving_params(cls, allow_permutation=True, save_permutation_graph=False, permutation_output_dir='.'):
"""This function is used to set the permutation saving related parameters in ASP class and inside of the Permutation class."""
print("\n[ASP][set_permutation_saving_param] Set permutation saving related parameters")
print("\n[set_permutation_saving_param] Set permutation saving related parameters")
cls.__allow_permutation = allow_permutation
print("[set_permutation_saving_param]\t Allow permutation: {}".format(cls.__allow_permutation))
cls.__save_permutation_graph = save_permutation_graph
print("[set_permutation_saving_param]\t Save permutation graphs: {}".format(cls.__save_permutation_graph))
cls.__permutation_output_dir = permutation_output_dir
print("[set_permutation_saving_param]\t Permutation graphs saving dir: {}".format(cls.__permutation_output_dir))
Permutation.set_permutation_saving_params(allow_permutation, save_permutation_graph, permutation_output_dir)
This diff is collapsed.
#include <stdio.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;
#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
{
fprintf(stderr,"GPUassert %d: %s %s %d\n", (int)code, cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
__device__ float group_2_to_4(float4 vals)
{
vals.x = fabs(vals.x);
vals.y = fabs(vals.y);
vals.z = fabs(vals.z);
vals.w = fabs(vals.w);
float sum0 = vals.x + vals.y;
float sum1 = vals.x + vals.z;
float sum2 = vals.x + vals.w;
float sum3 = vals.y + vals.z;
float sum4 = vals.y + vals.w;
float sum5 = vals.z + vals.w;
float best_sum0 = fmax(sum0, sum1);
float best_sum1 = fmax(sum2, sum3);
float best_sum2 = fmax(sum4, sum5);
float best_sum = fmax(fmax(best_sum0, best_sum1), best_sum2);
return best_sum;
}
inline float* float_ptr_from_numpy(py::array_t<float>& py_float)
{
return (float*)py_float.data();
}
inline unsigned int* uint_ptr_from_numpy(py::array_t<unsigned int>& py_uint)
{
return (unsigned int*)py_uint.data();
}
__global__ void subset_sum_after_2_to_4(float* matrix,
unsigned int rows,
unsigned int cols,
unsigned int start_col,
unsigned int end_col,
float* output)
{
// vectorize
float4* mat4 = (float4*) matrix;
cols /= 4;
start_col /= 4;
end_col /= 4;
// each thread in a block takes some number of rows
size_t num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1);
size_t row_offset = num_rows * threadIdx.x;
// each block takes some number of columns
size_t num_cols = (end_col - start_col) / gridDim.x;
size_t col_offset = num_cols * blockIdx.x;
start_col += col_offset;
end_col = start_col + num_cols;
float sum = 0.0f;
for ( unsigned int r = row_offset; r < row_offset + num_rows; ++r ) {
if (r < rows) {
for ( unsigned int c = start_col; c < end_col; c++ ) {
sum += group_2_to_4(mat4[r * cols + c]);
}
}
}
atomicAdd(output, sum);
}
// build the entire permute map at once
// each block handles one group of stripes
// each threads in the block handle all handle the same permutation at the same time on different rows before moving to the next permutation
__global__ void build_permute_map(float* matrix,
unsigned int rows,
unsigned int cols,
unsigned int* stripes,
unsigned int group_width,
unsigned int* permutations,
unsigned int num_permutations,
unsigned int perm_length,
float* output,
unsigned int* best_indices)
{
// vectorize
float4* mat4 = (float4*) matrix;
cols /= 4;
// each block handles a group of stripes
unsigned int* stripe_group = (unsigned int*)&stripes[blockIdx.x*group_width];
// shared memory: 32 threads each need 16*2
extern __shared__ float pm_shared[32][32];
float4* local_stripes = (float4*)&pm_shared[threadIdx.x];
float* local_columns = (float*) &pm_shared[threadIdx.x];
float4* permuted_stripes = (float4*) &local_stripes[4];
float* permuted_columns = (float*) &local_columns[16];
// each thread handles all permutations in the row before moving on to the next row
size_t num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1);
size_t row_offset = num_rows * threadIdx.x;
for ( unsigned int r = row_offset; r < row_offset + num_rows; ++r) {
if (r >= rows)
break;
// load a row into smem
for ( unsigned int s = 0; s < group_width; ++s) {
unsigned int const stripe = stripe_group[s];
local_stripes[s] = mat4[r*cols+stripe];
}
for ( unsigned int p = 0; p < num_permutations; ++p) {
unsigned int* permutation = &permutations[p*perm_length];
float sum = 0.0f;
// permute
#pragma unroll 4
for ( unsigned int c = 0; c < group_width*4; ++c) {
permuted_columns[c] = local_columns[permutation[c]];
}
// sum 2:4
for ( unsigned int s = 0; s < group_width; ++s) {
sum += group_2_to_4(permuted_stripes[s]);
}
// update the running sum for this stripe group's permutation
atomicAdd(&output[blockIdx.x*num_permutations + p], sum);
}
}
// at this point, each permutation's sum in this stripe group has been calculated
// now, find the best option
__syncthreads();
if (threadIdx.x == 0) {
unsigned int best_permutation = 0;
float best_magnitude = output[blockIdx.x*num_permutations];
float base_magnitude = best_magnitude;
//#pragma unroll 32
for (unsigned int p = 1; p < num_permutations; ++p) {
float magnitude = output[blockIdx.x*num_permutations+p];
if (magnitude > best_magnitude) {
best_permutation = p;
best_magnitude = magnitude;
}
}
output[blockIdx.x*num_permutations] = best_magnitude - base_magnitude;
best_indices[blockIdx.x] = best_permutation;
}
}
void free_sum_after_2_to_4_memory(float** dmatrix,
float** dresult)
{
cudaFree(*dmatrix);
cudaFree(*dresult);
}
int set_up_sum_after_2_to_4_memory(float** dmatrix,
unsigned int rows,
unsigned int cols,
float** dresult)
{
static unsigned int setupRows = 0;
static unsigned int setupCols = 0;
static bool allocated = false;
int fresh_allocation = 0;
if (!allocated ||
setupRows != rows ||
setupCols != cols)
{
if (allocated)
free_sum_after_2_to_4_memory(dmatrix, dresult);
gpuErrchk(cudaMalloc( (void**) dmatrix, rows*cols*sizeof(float)));
gpuErrchk(cudaMalloc( (void**) dresult, sizeof(float)));
setupRows = rows;
setupCols = cols;
fresh_allocation = 1;
}
allocated = true;
return fresh_allocation;
}
int run_subset_sum_after_2_to_4(py::array_t<float>& py_matrix,
unsigned int rows,
unsigned int cols,
unsigned int start_col,
unsigned int end_col,
unsigned int blocks,
unsigned int threads,
py::array_t<float>& py_output)
{
static float* d_matrix;
static float* d_result;
int fresh_allocation = set_up_sum_after_2_to_4_memory(&d_matrix, rows, cols, &d_result);
float* matrix = float_ptr_from_numpy(py_matrix);
float* output = float_ptr_from_numpy(py_output);
gpuErrchk(cudaMemcpy( d_matrix, matrix, rows*cols*sizeof(float), cudaMemcpyHostToDevice ));
gpuErrchk(cudaMemset( d_result, 0, sizeof(float)));
subset_sum_after_2_to_4<<<blocks, threads>>>(d_matrix, rows, cols, start_col, end_col, d_result);
gpuErrchk(cudaDeviceSynchronize());
gpuErrchk(cudaMemcpy( output, d_result, sizeof(float), cudaMemcpyDeviceToHost ));
return 0;
}
void set_up_permute_map_memory(float** dmatrix,
unsigned int rows,
unsigned int cols,
unsigned int** dstripes,
unsigned int num_groups,
unsigned int group_width,
unsigned int** dpermutations,
unsigned int num_permutations,
unsigned int perm_length,
float** doutput,
unsigned int** dindices,
float** hresult,
unsigned int** hindices)
{
static unsigned int setUpRows = 0;
static unsigned int setUpCols = 0;
static unsigned int setUpGroupWidth = 0;
static unsigned int setUpNumGroups = 0;
static unsigned int setUpNumPerms = 0;
static unsigned int setUpPermLength = 0;
if (setUpRows != rows ||
setUpCols != cols) {
if (*dmatrix != NULL) { gpuErrchk(cudaFree(*dmatrix)); *dmatrix = NULL; }
gpuErrchk(cudaMalloc( (void**) dmatrix, rows*cols*sizeof(float)));
}
if (setUpGroupWidth < group_width ||
setUpNumGroups < num_groups) {
if (*dstripes != NULL) { gpuErrchk(cudaFree(*dstripes)); *dstripes = NULL; }
gpuErrchk(cudaMalloc( (void**) dstripes, num_groups*group_width*sizeof(unsigned int)));
if (setUpNumGroups < num_groups) {
if (*dindices != NULL) { gpuErrchk(cudaFree(*dindices)); *dindices = NULL; }
gpuErrchk(cudaMalloc( (void**) dindices, num_groups*sizeof(unsigned int)));
if (*hindices != NULL) { free(*hindices); *hindices = NULL; }
*hindices = (unsigned int*) malloc (num_groups*sizeof(unsigned int));
}
}
if (setUpNumPerms < num_permutations ||
setUpPermLength < perm_length) {
if (*dpermutations != NULL) { gpuErrchk(cudaFree(*dpermutations)); *dpermutations = NULL; }
gpuErrchk(cudaMalloc( (void**) dpermutations, perm_length*num_permutations*sizeof(unsigned int)));
}
if (setUpNumPerms < num_permutations ||
setUpNumGroups < num_groups) {
if (*doutput != NULL) { gpuErrchk(cudaFree(*doutput)); *doutput = NULL; }
gpuErrchk(cudaMalloc( (void**) doutput, num_permutations*num_groups*sizeof(float)));
if (*hresult != NULL) { free(*hresult); *hresult = NULL; }
*hresult = (float*) malloc(num_permutations*num_groups*sizeof(float));
}
setUpRows = rows;
setUpCols = cols;
setUpGroupWidth = group_width;
setUpNumGroups = num_groups;
setUpNumPerms = num_permutations;
setUpPermLength = perm_length;
}
int run_build_permute_map(py::array_t<float>& py_matrix,
unsigned int rows,
unsigned int cols,
py::array_t<unsigned int>& py_stripes,
unsigned int num_groups,
unsigned int group_width,
py::array_t<unsigned int>& py_permutations,
//unsigned int num_permutations,
unsigned int perm_length,
py::array_t<float>& py_improvements,
py::array_t<unsigned int>& py_best_indices)
{
static float* d_matrix = NULL;
static unsigned int* d_stripes = NULL;
static unsigned int* d_permutations = NULL;
static float* d_output = NULL;
static unsigned int* d_indices = NULL;
static float* hresult = NULL;
static unsigned int* hindices = NULL;
//const unsigned int cols = py_matrix.size() / rows;
//const unsigned int num_groups = py_stripes.size() / group_width;
//const unsigned int perm_length = group_width * 4; // 2:4 sparsity - each stripe in the group is 4 elements wide
const unsigned int num_permutations = py_permutations.size() / perm_length;
const unsigned int MAX_GROUPS_PER_LAUNCH = num_permutations <= 5775 ? 1820 : 40;
const unsigned int full_launches = num_groups / MAX_GROUPS_PER_LAUNCH;
const unsigned int final_launch = num_groups % MAX_GROUPS_PER_LAUNCH;
const unsigned int launches = full_launches + (final_launch != 0 ? 1 : 0);
set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups,MAX_GROUPS_PER_LAUNCH), group_width, &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, &hindices);
float* matrix = float_ptr_from_numpy(py_matrix);
unsigned int* stripes = uint_ptr_from_numpy(py_stripes);
unsigned int* permutations = uint_ptr_from_numpy(py_permutations);
float* improvements = float_ptr_from_numpy(py_improvements);
unsigned int* best_indices = uint_ptr_from_numpy(py_best_indices);
gpuErrchk(cudaMemcpy( d_matrix, matrix, rows*cols*sizeof(float), cudaMemcpyHostToDevice ));
gpuErrchk(cudaMemcpy( d_permutations, permutations, num_permutations*perm_length*sizeof(unsigned int), cudaMemcpyHostToDevice ));
unsigned int group_offset = 0;
for (unsigned int l = 0; l < launches; ++l)
{
unsigned int groups_this_launch = (l < full_launches) ? MAX_GROUPS_PER_LAUNCH : final_launch;
gpuErrchk(cudaMemcpy( d_stripes, &stripes[group_offset*group_width], groups_this_launch*group_width*sizeof(unsigned int), cudaMemcpyHostToDevice ));
gpuErrchk(cudaMemset( d_output, 0, groups_this_launch*num_permutations*sizeof(float)));
gpuErrchk(cudaMemset( d_indices, 0, groups_this_launch*sizeof(unsigned int)));
unsigned int shmem = 32*(32)*sizeof(float);
build_permute_map<<<groups_this_launch, 32, shmem>>>(d_matrix, rows, cols, d_stripes, group_width, d_permutations, num_permutations, perm_length, d_output, d_indices);
gpuErrchk(cudaDeviceSynchronize());
gpuErrchk(cudaMemcpy( hresult, d_output, num_permutations*groups_this_launch*sizeof(float), cudaMemcpyDeviceToHost ));
gpuErrchk(cudaMemcpy( hindices, d_indices, groups_this_launch*sizeof(unsigned int), cudaMemcpyDeviceToHost ));
// thread0 stuck the minimum in the first slot of each group
for (unsigned int g = 0; g < groups_this_launch; ++g) {
improvements[group_offset+g] = hresult[g*num_permutations];
best_indices[group_offset+g] = hindices[g];
}
group_offset += groups_this_launch;
}
return 0;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("sum_after_2_to_4", &run_subset_sum_after_2_to_4, "matrix sum after applying 2:4 (CUDA)");
m.def("build_permute_map", &run_build_permute_map, "optimize stripe groups (CUDA)");
}
\ No newline at end of file
from .call_permutation_search_kernels import accelerated_search_for_good_permutation
from .permutation_utilities import sum_after_2_to_4
\ No newline at end of file
import numpy as np
from .permutation_utilities import *
from .exhaustive_search import Exhaustive_Search
def accelerated_search_for_good_permutation(matrix_group, options=None):
"""This function is used to call the permutation search CUDA kernels.
users can provide prefer search strategy by providing a valid 'options' as a dictionary,
or users can implement their customized 'accelerated_search_for_good_permutation' function.
"""
input_matrix = matrix_group.cpu().detach().numpy()
print("\n[accelerated_search_for_good_permutation] input matrix shape: \'{:}\'.".format(input_matrix.shape))
result = np.copy(input_matrix)
# init a sequential permutation search sequence
input_channel_num = matrix_group.size()[1]
permutation_sequence = [n for n in range(input_channel_num)]
duration = 0.0
if options == None:
options = {}
if 'strategy' not in options: # right now, the default permutation search strategy is: 'exhaustive' search
options['strategy'] = 'exhaustive'
print("[accelerated_search_for_good_permutation] the permutation strategy is: \'{:} search\'.".format(options['strategy']))
# define sub options for each search strategy
if options['strategy'] == 'exhaustive':
# right now, the default options for 'exhaustive' search is: 'exhaustive,8,100'
if 'stripe_group_size' not in options:
options['stripe_group_size'] = 8
if 'escape_attempts' not in options:
options['escape_attempts'] = 100
elif options['strategy'] == 'progressive channel swap':
# just swaps meaningful channels, keeping the good swaps, until the search time limit expires.
if 'progressive_search_time_limit' not in options:
options['progressive_search_time_limit'] = 60
if 'improvement_threshold' not in options:
options['improvement_threshold'] = 1e-9
# execute the requested strategy
if options['strategy'] == 'exhaustive':
result, duration, permutation_sequence = Exhaustive_Search(result, stripe_group_size=options['stripe_group_size'], escape_attempts=options['escape_attempts'])
elif options['strategy'] == 'progressive channel swap':
real_swap_num = 0
start_time = time.perf_counter()
while time.perf_counter() - start_time < options['progressive_search_time_limit']:
src = np.random.randint(result.shape[1])
dst = np.random.randint(result.shape[1])
src_group = int(src/4)
dst_group = int(dst/4)
if src_group == dst_group: # channel swapping within a stripe does nothing
continue
new_sum, improvement = try_swap(result, dst, src)
if improvement > options['improvement_threshold']:
result[...,[src,dst]] = result[...,[dst,src]]
permutation_sequence[src], permutation_sequence[dst] = permutation_sequence[dst], permutation_sequence[src]
real_swap_num += 1
duration = time.perf_counter() - start_time
print("\tFinally swap {} channel pairs until the search time limit expires.".format(real_swap_num))
elif options['strategy'] == 'user defined': # need to get the permutated matrix (result) by applying customized permutation search function
print("[accelerated_search_for_good_permutation] Use the user customized permutation search function!")
else:
print("[accelerated_search_for_good_permutation] Cannot find the implementation of the required strategy!")
print("[accelerated_search_for_good_permutation] Take {:.4f} seconds to search the permutation sequence.".format(duration))
# In the new version of Exhaustive_Search function, there’s no need to use the find_permutation(result, input_matrix) function
# to recover the permutation sequence applied to the input_matrix to get the result separately any more.
#start_time_find_permutation = time.perf_counter()
#permutation_sequence = find_permutation(result, input_matrix)
#duration_find_permutation = time.perf_counter() - start_time_find_permutation
#print("[accelerated_search_for_good_permutation] Take {:.4f} seconds to finish find_permutation function.".format(duration_find_permutation))
#print("[accelerated_search_for_good_permutation] The permutation sequence is: {:}".format(permutation_sequence))
#print("[accelerated_search_for_good_permutation] The length of permutation sequence is: {:}".format(len(permutation_sequence)))
return permutation_sequence
from .permutation_utilities import *
################################################################################################################
# Exhaustive
# Try them all
# - order of columns within a group doesn't matter
# - order of groups doesn't matter
# - we can eliminate effective duplicates by defining aunique combination to be a sorted list of sorted groups
################################################################################################################
####################################################################
# generate unique permutations
####################################################################
# check if adding a column index to a current permutation would keep it in canonical form
# assumes that perm is in canonical form already!
def is_canonical(perm, col):
# if it's a new group
if len(perm) % 4 == 0:
# every column ID < col needs to be in the permutation already
for val in range(col):
if val not in perm:
return False
# this new group needs to be sorted w.r.t. the previous group
return col > perm[-4]
# not a new group, just check to see if it will still be sorted
return col > perm[-1]
# recursive: build a unique permutation one column index at a time
def generate_unique_combinations(built_permutation, remaining_columns, full_permutation_list, group_width):
# base case: nothing else to add
if len(remaining_columns) == 0:
full_permutation_list.append(np.copy(built_permutation))
if len(full_permutation_list) % 1000000 == 0:
print(f"{len(full_permutation_list)} unique permutations found so far")
# still more choices to make, so add each remaining column in turn column if it keeps everything sorted
else:
for c in range(len(remaining_columns)):
# to satisfy our immutables (values within groups are sorted, groups are globally sorted),
# only add this column if either:
# it's starting a new group and is larger than the previous group's first entry
# OR
# it's larger than the last value in the built_permutation
col_to_add = remaining_columns[c]
if is_canonical(built_permutation, col_to_add):
# add the column to the running permutation, remove it from remaining columns
built_permutation.append(col_to_add)
remaining_columns.pop(c)
# recurse
generate_unique_combinations(built_permutation, remaining_columns, full_permutation_list, group_width)
# remove the most recent column and put it back on the remaining column list where we found it (sorted)
remaining_columns.insert(c, built_permutation.pop(-1))
import pickle
import os.path
from os import path
master_unique_permutation_list = {}
def generate_all_unique_combinations(C, M, must_use_all_groups = False):
global master_unique_permutation_list
if len(master_unique_permutation_list) == 0 and path.exists("master_list.pkl"):
with open("master_list.pkl","rb") as cache:
master_unique_permutation_list = pickle.load(cache)
if (C,M) not in master_unique_permutation_list:
full_permutation_list = []
generate_unique_combinations([0], [c for c in range(1,C)], full_permutation_list, M)
master_unique_permutation_list[(C,M)] = full_permutation_list
with open("master_list.pkl", "wb") as cache:
pickle.dump(master_unique_permutation_list, cache)
unique_permutations = master_unique_permutation_list[(C,M)]
return unique_permutations
# analytical solution
import math
def predict_unique_combinations(C, M):
assert(C%M==0)
G = int(C/M)
return int(int(math.factorial(C)) / (int(math.pow(math.factorial(M),G)) * math.factorial(G)))
#################################################################
# exhaustively try all unique permutations
#################################################################
# exhaustively search the entire matrix
def search_matrix(matrix, group_width):
# give up quickly if we'd go on forever
prediction = predict_unique_combinations(matrix.shape[1], group_width)
best_permutation = [c for c in range(matrix.shape[1])]
if prediction > 1e10:
print(f"There are {prediction} unique combinations with {matrix.shape[1]} columns and a group width of {group_width}, not searching.")
return matrix, prediction, best_permutation
start_time = time.perf_counter()
full_permutation_list = generate_all_unique_combinations(matrix.shape[1], group_width)
# found them, now try them
best_improvement = 0.0
base_sum = sum_after_2_to_4(matrix)
for i in range(1,len(full_permutation_list)):
permutation = full_permutation_list[i]
permuted = matrix[:, permutation]
cur_improvement = sum_after_2_to_4(permuted) - base_sum
if (cur_improvement > best_improvement):
best_improvement = cur_improvement
best_permutation = permutation
seconds = time.perf_counter() - start_time
return matrix[:, best_permutation], seconds, best_permutation, best_improvement
#############
# Stripe group handling
#############
# gather stripes from a larger matrix into a single matrix
def collect_stripes(matrix, stripes, group_width):
subset = np.zeros((matrix.shape[0], len(stripes)*group_width))
#print("[Debug][collect_stripes] matrix shape info: {}".format(matrix.shape))
#print("[Debug][collect_stripes] subset info: {}, {}, {}".format(matrix.shape[0], len(stripes), group_width))
for s,stripe in enumerate(stripes):
#print("[Debug][collect_stripes] s: {}, stripe: {}".format(s, stripe))
subset[...,s*group_width:s*group_width+group_width] = matrix[...,stripe*group_width:stripe*group_width+group_width]
return subset
# apply the stripe group permutation to the entire permutation
def apply_stripe_group_permutation(sgp, stripes, group_width, permutation):
new_permutation = permutation.copy()
for subset_idx in range(len(sgp)):
dst_stripe_idx = stripes[int(subset_idx / group_width)]
dst_col_idx = subset_idx % group_width
subset_val = sgp[subset_idx]
src_stripe_idx = stripes[int(subset_val / group_width)]
src_col_idx = subset_val % group_width
new_permutation[dst_stripe_idx*group_width + dst_col_idx] = permutation[src_stripe_idx*group_width + src_col_idx]
return new_permutation
# generate all possible stripe groups
def generate_stripe_groups(num_stripes, window_size):
stripe_array = [[c] for c in range(num_stripes)]
next_stripe_array = []
for w in range(1, window_size):
for g in range(len(stripe_array)):
start_c = stripe_array[g][w-1]+1
group = stripe_array[g]
for c in range(start_c, num_stripes):
new_group = group.copy()
new_group.append(c)
next_stripe_array.append(new_group)
stripe_array = next_stripe_array
next_stripe_array = []
return set(tuple(stripe_array[g]) for g in range(len(stripe_array)))
# It is not safe to just reset the stripe_set as None here.
# When calling the Exhaustive_Search in E2E search, the stripe_set will not be reset as None.
stripe_set = None
stripe_set_config = None
# build the stripe map
def build_stripe_map(matrix, group_width, window_size, stripe_map, stripe_ids, perm_map, used_stripes):
global stripe_set, stripe_set_config
#print("[Debug][build_stripe_map] Now the stripe_set value is: {}".format(stripe_set))
window_size = int(window_size / group_width)
if stripe_set is None or stripe_set_config is None or stripe_set_config != (group_width, window_size):
num_stripes = int(matrix.shape[1] / group_width)
assert(group_width * num_stripes == matrix.shape[1])
stripe_set = generate_stripe_groups(num_stripes, window_size)
#print("[Debug][build_stripe_map] Update stripe_set value as: {}".format(stripe_set))
stripe_set_config = (group_width, window_size)
# step through each, update the stripe_map/stripe_ids if necessary
updates = 0
use_cuda = use_gpu()
gpu_list = []
gpu_groups = []
for i,s in enumerate(stripe_set):
sg = [] # build the group of stripes, check if any members changed
need_update = i >= len(stripe_map)
for stripe in s:
sg.append(stripe)
if stripe in used_stripes:
need_update = True
# pre-populate if we're building fresh
if i >= len(stripe_map):
stripe_ids.append(sg)
stripe_map.append(0.)
perm_map.append([c for c in range(group_width * window_size)])
# update entries if needed (only stripe_map and perm_map)
if need_update:
updates += 1
if not use_cuda: # do the work here if using the CPU
subset = collect_stripes(matrix, sg, group_width)
sub_result, sub_duration, permutation, improvement = search_matrix(subset, group_width)
stripe_map[i] = improvement
perm_map[i] = permutation
else: # otherwise, just track the work needed to farm off to the GPU
gpu_groups.append(sg)
gpu_list.append(i)
if use_cuda: # if using the GPU, perform the work
matrix_view = np.copy(matrix).astype(np.float32).flatten()
all_permutations = generate_all_unique_combinations(window_size*group_width, group_width)
num_permutations = len(all_permutations)
permutation_view = np.copy(np.asarray(all_permutations)).astype(np.uint32).flatten()
stripe_groups_view = np.asarray(gpu_groups).astype(np.uint32).flatten()
num_gpu_groups = len(gpu_list)
gpu_improvement = np.zeros((num_gpu_groups), dtype=np.float32).flatten()
gpu_permutation = np.zeros((num_gpu_groups), dtype=np.uint32).flatten()
result = permutation_search_cuda_kernels.build_permute_map(matrix_view,
matrix.shape[0],
matrix.shape[1],
stripe_groups_view,
num_gpu_groups,
window_size,
permutation_view,
window_size * group_width,
gpu_improvement,
gpu_permutation)
# put the data where python expects it
for i in range(len(gpu_list)):
stripe_map[gpu_list[i]] = gpu_improvement[i]
perm_map[gpu_list[i]] = all_permutations[gpu_permutation[i]]
return stripe_map, stripe_ids, perm_map
# start performing stripe checks
sm_perturbations = 0
sm_perturbation_limit = 0
def use_stripe_map(matrix, group_width, stripe_map, stripe_ids, perm_map, permutation):
global sm_perturbations, sm_perturbation_limit
used_stripes = []
stripe_groups_optimized = 0
improvement = 0.0
# set the traversal order
ix = np.flip(np.argsort(stripe_map)) # small to large --> large to small
for i in range(len(ix)):
stripe_group_id = ix[i]
perm = perm_map[stripe_group_id].copy()
if stripe_map[stripe_group_id] <= 0.0001:
# perturbations
if len(used_stripes) == 0 and sm_perturbations < sm_perturbation_limit:
sm_perturbations += 1
# use this permutation, but swap two channels from left/right halves to include two stripes, no matter the group size
stripe_group_id = ix[np.random.randint(len(ix))]
perm = perm_map[stripe_group_id].copy()
# a little easier to escape from
src = np.random.randint(int(len(perm)/2))
dst = int(len(perm)/2) + np.random.randint(int(len(perm)/2))
perm[src],perm[dst] = perm[dst],perm[src]
else:
break
stripe_group = stripe_ids[stripe_group_id]
# don't work on stripes we've already touched
touched_stripe = False
for stripe in stripe_group:
if stripe in used_stripes:
touched_stripe = True
if touched_stripe:
continue
# apply the permutation we've already found to this stripe group
subset = collect_stripes(matrix, stripe_group, group_width)
sub_result = subset[...,perm]
permutation = apply_stripe_group_permutation(perm, stripe_group, group_width, permutation)
# scatter the results, track what changed
for s,stripe in enumerate(stripe_group):
# see if this group is in canonical form (entry 0 a multiple of 4, contiguous values))
group = perm[s*group_width:s*group_width+group_width] # columns in this group of the used permutation
changed = False
if group[0] % 4 != 0:
changed = True
for c in range(1,group_width):
if group[c] != group[c-1]+1:
changed = True
break
# if it's not, then it changed
if changed:
used_stripes.append(stripe_group[s])
matrix[...,stripe*group_width:stripe*group_width+group_width] = sub_result[...,s*group_width:s*group_width+group_width]
improvement += stripe_map[stripe_group_id]
stripe_groups_optimized += 1
return matrix, stripe_groups_optimized, stripe_map, stripe_ids, used_stripes, improvement, permutation
# entry point for exhaustive searches - both the entire matrix, as well as stripe groups
def Exhaustive_Search(matrix, stripe_group_size=-1, escape_attempts=0, permutation=None):
global sm_perturbation_limit, sm_perturbations
sm_perturbations = 0
sm_perturbation_limit = escape_attempts
if permutation is None:
permutation = [c for c in range(matrix.shape[1])]
# It is much safer to reset the stripe_set as None in the entry point of Exhaustive_Search
global stripe_set, stripe_set_config
stripe_set = None
stripe_set_config = None
# only support N:4 for now
group_width = 4
result = np.copy(matrix)
# if the matrix is too large for a window size of 12, subdivide, then fix up with a global optimization with a window size of 8
if group_width==4 and stripe_group_size==12 and matrix.shape[1] > 512:
stripe_split = int(matrix.shape[1]/2/group_width)
col_split = stripe_split * group_width
result[:,:col_split], durationL, permutation[:col_split] = Exhaustive_Search(result[:,:col_split], stripe_group_size=stripe_group_size, escape_attempts=escape_attempts, permutation=permutation[:col_split])
result[:,col_split:], durationR, permutation[col_split:] = Exhaustive_Search(result[:,col_split:], stripe_group_size=stripe_group_size, escape_attempts=escape_attempts, permutation=permutation[col_split:])
escape_attempts = max(escape_attempts, 100)*10
result,duration,permutation = Exhaustive_Search(result, stripe_group_size=8, escape_attempts=escape_attempts, permutation=permutation)
return result, durationL+durationR+duration, permutation
# small enough to optimize the entire matrix at once
if stripe_group_size != -1 and stripe_group_size < matrix.shape[1]:
stripe_map = []
stripe_ids = []
perm_map = []
used_stripes = []
optimized_groups_count = 0
agg_improvement = 0.
cur_total_sum = sum_after_2_to_4(result)
# in practice, this work will be cached ahead of time; doing it now.
# (Reading the cached list from disk can take several seconds, which shouldn't be counted against the search, but amortized over every layer in a network)
generate_all_unique_combinations(stripe_group_size, group_width)
start_time = time.perf_counter()
while True:
#print("[Debug][Exhaustive_Search] Before entering the build_stripe_map function.")
#print("[Debug][Exhaustive_Search] Now the stripe_set value is: {}".format(stripe_set))
stripe_map, stripe_ids, perm_map = build_stripe_map(result, group_width, stripe_group_size, stripe_map, stripe_ids, perm_map, used_stripes)
result, stripe_groups_optimized, stripe_map, stripe_ids, used_stripes, improvement, permutation = use_stripe_map(result, group_width, stripe_map, stripe_ids, perm_map, permutation)
# converged?
if len(used_stripes) == 0:
break
duration = time.perf_counter() - start_time
else: # no sliding window, single iteration
print(f"Matrix has {matrix.shape[1]} columns and the search window is only {stripe_group_size}: searching exhaustively")
result, duration, permutation, improvement = search_matrix(matrix, group_width)
return result, duration, permutation
import numpy as np
import time
import ctypes
import subprocess
import os
import math
gpus_tested = False
gpus_found = 0
kernels_found = True
try:
import permutation_search_cuda as permutation_search_cuda_kernels
print(f"Found permutation search CUDA kernels")
except ImportError:
print(f"Could not find permutation search CUDA kernels, falling back to CPU path")
kernels_found = False
def use_gpu(initial_override = True):
global gpus_tested, gpus_found, kernels_found
if not gpus_tested:
if not initial_override:
gpus_tested = True
return False
try:
gpus_found = str(subprocess.check_output(["nvidia-smi", "-L"])).count('UUID')
print(f"Found {gpus_found} gpus")
except:
gpus_found = 0
print(f"Could not find nvidia-smi, please check your cuda installation")
gpus_tested = True
return gpus_found > 0 and kernels_found
##############################################################################################
# pruning utilities
##############################################################################################
## apply 2:4 to some matrix
def apply_2_to_4(matrix):
for row in range(matrix.shape[0]):
for col in range(0,matrix.shape[1],4):
ix = np.argsort(np.abs(matrix[row,col:col+4]))
matrix[row,col+ix[0]] = 0.0
matrix[row,col+ix[1]] = 0.0
return matrix
## find the sum of magnitudes if 2:4 were applied to a matrix
def sum_after_2_to_4(matrix):
#matrix = np.copy(matrix)
cur_sum = 0.0
use_cuda = use_gpu()
if not use_cuda:
start_time = time.perf_counter()
for row in range(matrix.shape[0]):
for col in range(0,matrix.shape[1],4):
ix = np.argsort(np.abs(matrix[row,col:col+4]))
cur_sum += abs(matrix[row,col+ix[2]])
cur_sum += abs(matrix[row,col+ix[3]])
np_elapsed = time.perf_counter() - start_time
else:
matrix = matrix.astype(np.float32)
cuda_sum = np.zeros((1), dtype=np.float32)
start_time = time.perf_counter()
matrix_view = np.copy(matrix).flatten()
sum_view = cuda_sum.flatten()
blocks = max(int(matrix.shape[1]/4/2), 1)
threads = min(max(math.ceil(matrix.shape[0]/4), 1), 1024)
result = permutation_search_cuda_kernels.sum_after_2_to_4(matrix_view,
matrix.shape[0],
matrix.shape[1],
0,
matrix.shape[1],
blocks,
threads,
sum_view)
cuda_elapsed = time.perf_counter() - start_time
#print(cuda_sum, cuda_elapsed, cur_sum, np_elapsed, np_elapsed/cuda_elapsed)
cur_sum = sum_view[0]
return cur_sum
## try swapping columns and tracking magnitude after pruning
def try_swap(matrix, dst, src):
src_base = sum_after_2_to_4(matrix[...,int(src/4)*4:int(src/4)*4+4])
dst_base = sum_after_2_to_4(matrix[...,int(dst/4)*4:int(dst/4)*4+4])
# swap
matrix[...,[src,dst]] = matrix[...,[dst,src]]
# check the Nx4 slices of the swapped columns
src_sum = sum_after_2_to_4(matrix[...,int(src/4)*4:int(src/4)*4+4])
dst_sum = sum_after_2_to_4(matrix[...,int(dst/4)*4:int(dst/4)*4+4])
# swap back
matrix[...,[src,dst]] = matrix[...,[dst,src]]
return src_sum + dst_sum, (src_sum + dst_sum) - (src_base + dst_base)
##############################################################################################
# permutation utilities
##############################################################################################
## find the permutation needed to make matrix A look like matrix B
def find_permutation(A, B):
permutation = []
for col in range(A.shape[1]):
Avals = A[...,col]
for bcol in range(B.shape[1]):
if np.all(Avals - B[...,bcol] == np.zeros(Avals.shape)):
permutation.append(bcol)
break
return permutation
...@@ -55,7 +55,7 @@ def main(args): ...@@ -55,7 +55,7 @@ def main(args):
step = train_loop(args, model, optimizer, step, args.num_dense_steps) step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights # simulate sparsity by inserting zeros into existing dense weights
ASP.enable_sparsity() ASP.compute_sparse_masks()
# train for a few steps with sparse weights # train for a few steps with sparse weights
print("SPARSE :: ",one_ll) print("SPARSE :: ",one_ll)
......
...@@ -50,7 +50,7 @@ def main(step, args, model_state_dict, optimizer_state_dict): ...@@ -50,7 +50,7 @@ def main(step, args, model_state_dict, optimizer_state_dict):
model.load_state_dict(model_state_dict) model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict) optimizer.load_state_dict(optimizer_state_dict)
print("Model sparsity is %s" % ("enabled" if ASP.sparsity_is_enabled() else "disabled")) print("Model sparsity is %s" % ("enabled" if ASP.is_sparsity_enabled() else "disabled"))
# train for a few steps with sparse weights # train for a few steps with sparse weights
print("SPARSE :: ",one_ll) print("SPARSE :: ",one_ll)
......
...@@ -59,7 +59,7 @@ def main(args): ...@@ -59,7 +59,7 @@ def main(args):
step = train_loop(args, model, optimizer, step, args.num_dense_steps) step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights # simulate sparsity by inserting zeros into existing dense weights
ASP.enable_sparsity() ASP.compute_sparse_masks()
# train for a few steps with sparse weights # train for a few steps with sparse weights
print("SPARSE :: ",one_ll) print("SPARSE :: ",one_ll)
......
...@@ -295,6 +295,20 @@ if "--cuda_ext" in sys.argv: ...@@ -295,6 +295,20 @@ if "--cuda_ext" in sys.argv:
) )
) )
if "--permutation_search" in sys.argv:
sys.argv.remove("--permutation_search")
if CUDA_HOME is None:
raise RuntimeError("--permutation_search was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
cc_flag = ['-Xcompiler', '-fPIC', '-shared']
ext_modules.append(
CUDAExtension(name='permutation_search_cuda',
sources=['apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu'],
include_dirs=[os.path.join(this_dir, 'apex', 'contrib', 'sparsity', 'permutation_search_kernels', 'CUDA_kernels')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros + cc_flag}))
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
sys.argv.remove("--bnp") sys.argv.remove("--bnp")
raise_if_cuda_home_none("--bnp") raise_if_cuda_home_none("--bnp")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment