Commit e3794f42 authored by Asit, committed by GitHub

Update README for ASP

Added an outline to illustrate our recommended recipe to obtain a pruned model
parent eb5e96c2
...@@ -15,7 +15,6 @@ ASP.prune_trained_model(model, optimizer)
```
In a typical PyTorch training loop, it might look like this:
```
ASP.prune_trained_model(model, optimizer)
...@@ -34,4 +33,38 @@ The `prune_trained_model` calculates the sparse mask and applies it to the weights
ASP.compute_sparse_masks()
```
A more thorough example can be found in `./test/toy_problem.py`.
\ No newline at end of file
The following approach serves as a guiding example of how to generate a pruned model that can use the Sparse Tensor Cores in the NVIDIA Ampere architecture. This approach generates a model for deployment, i.e., inference mode.
```
(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern (illustrated in the sketch below).
(2) Fine-tune the pruned model with the optimization method and hyper-parameters (learning rate, schedule, number of epochs, etc.) exactly as used to obtain the dense trained model.
(3) (If required) Quantize the model.
```
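Step (1) refers to the 2:4 structured sparsity pattern: within every group of four consecutive weight values, at most two are non-zero. Below is a minimal illustrative sketch (not part of the ASP API; the helper `is_2to4_sparse` is hypothetical and checks a simplified, flattened form of the pattern) showing what pruning to 2:4 sparsity means for a weight tensor:
```
import torch

def is_2to4_sparse(weight):
    # Group the flattened tensor into chunks of 4 and count non-zeros per chunk
    # (assumes weight.numel() is divisible by 4, as in this toy example).
    groups = weight.reshape(-1, 4)
    nonzeros_per_group = (groups != 0).sum(dim=1)
    return bool((nonzeros_per_group <= 2).all())

dense = torch.randn(4, 8)                        # toy weight tensor
mask = torch.tensor([[1., 1., 0., 0.] * 2] * 4)  # keep 2 of every 4 values
print(is_2to4_sparse(dense * mask))  # True: the pruned tensor follows the 2:4 pattern
print(is_2to4_sparse(dense))         # False (almost surely) for a random dense tensor
```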
In code, below is a sketch of how to use ASP for this approach (steps 1 and 2 above).
```
model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ...                            # compare ground truth with model prediction; use the same criterion as used to generate the dense trained model
optimizer = ...                            # optimize model parameters; use the same optimizer as used to generate the dense trained model
lr_scheduler = ...                         # learning rate scheduler; use the same schedule as used to generate the dense trained model

from apex.contrib.sparsity import ASP
ASP.prune_trained_model(model, optimizer)  # prune the trained model and compute the sparsity masks

x, y = DataLoader(args)                    # placeholder for the same data pipeline used to train the dense model
for epoch in range(epochs):                # fine-tune for the same number of epochs as used to generate the dense trained model
    y_pred = model(x)
    loss = criterion(y_pred, y)
    optimizer.zero_grad()                  # clear gradients from the previous step
    loss.backward()
    optimizer.step()
    lr_scheduler.step()                    # step the schedule once per epoch, after the optimizer update

torch.save(...)                            # saves the pruned checkpoint with the sparsity masks
```
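Step (3) is independent of ASP; any standard quantization flow can be applied to the fine-tuned pruned model. As one possibility, here is a minimal sketch using PyTorch's post-training dynamic quantization, continuing from the `model` in the sketch above (the choice of quantized module types and dtype is an assumption for illustration, not a recommendation from ASP):
```
import torch

model.eval()  # quantize the fine-tuned, pruned model in inference mode

# Dynamically quantize the weights of Linear layers to 8-bit integers.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8,
)

torch.save(quantized_model.state_dict(), "pruned_quantized.pt")
```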