# Introduction to ASP

This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.

## Importing ASP
```
from apex.contrib.sparsity import ASP
```

## Initializing ASP

Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:
```
ASP.prune_trained_model(model, optimizer)
```

In the context of a typical PyTorch training loop, it might look like this:
```
ASP.prune_trained_model(model, optimizer)

for epoch in range(epochs):
    for x, y in dataloader:
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_function(y_pred, y)
        loss.backward()
        optimizer.step()

torch.save(...)
```
The `prune_trained_model` step computes the sparse mask and applies it to the weights. This is done once, i.e., the sparse locations in the weight tensors remain fixed after this step.
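
As a quick sanity check (not part of the ASP API), the resulting 2:4 pattern can be verified by counting zeros in each group of four consecutive weight values. The helper below is a hypothetical sketch; it assumes the `model` from the example above and weight tensors whose size is divisible by 4:

```
import torch

def is_2to4_sparse(weight):
    # group consecutive values into chunks of 4 (requires numel divisible by 4)
    groups = weight.detach().reshape(-1, 4)
    # the 2:4 pattern requires at least two zeros in every group of four
    return bool(((groups == 0).sum(dim=1) >= 2).all())

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name, is_2to4_sparse(module.weight))
```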

## Generate a Sparse Network

The following approach serves as a guiding example of how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach generates a model for deployment, i.e., inference mode.

```
(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern.
(2) Fine-tune the pruned model with the same optimization method and hyper-parameters (learning rate, schedule, number of epochs, etc.) as those used to obtain the trained model.
(3) (If required) Quantize the model.
```

In code, below is a sketch of how to use ASP for this approach (steps 1 and 2 above).

```
model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ... # compare ground truth with model prediction; use the same criterion as used to generate the dense trained model
optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model
lr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model

from apex.contrib.sparsity import ASP
ASP.prune_trained_model(model, optimizer) # prune the trained model

for epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model
    for x, y in dataloader:
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
    lr_scheduler.step()

torch.save(...) # save the pruned checkpoint with sparsity masks
```

## Non-Standard Usage

If your goal is to easily prepare a network for accelerated inference, please follow the recipe above. However, ASP can also be used to perform experiments with advanced techniques such as training with sparsity from initialization. For example, to recompute the sparse mask between training steps, use the following method:

```
ASP.compute_sparse_masks()
```
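
For experiments like this, the one-shot `prune_trained_model` call can be replaced by the lower-level initialization entry points it wraps. Below is a minimal sketch under that assumption; `init_model_for_pruning`/`init_optimizer_for_pruning` and their arguments follow the ASP source and may differ between apex versions, and the recompute interval `N`, `dataloader`, and `criterion` are placeholders:

```
import torch
from apex.contrib.sparsity import ASP

# allow_recompute_mask=True keeps pruned values around so masks can be recomputed later
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d",
                           whitelist=[torch.nn.Linear, torch.nn.Conv2d],
                           allow_recompute_mask=True)
ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks() # compute the initial masks

N = 100 # placeholder recompute interval
for step, (x, y) in enumerate(dataloader):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    if (step + 1) % N == 0:
        ASP.compute_sparse_masks() # recompute sparse locations from the current weights
```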

A more thorough example can be found in `./test/toy_problem.py`.