```
ASP.prune_trained_model(model, optimizer)
```
In a typical PyTorch training loop, it might look like this:
```
from apex.contrib.sparsity import ASP
ASP.prune_trained_model(model, optimizer)

x, y = DataLoader(args)
for epoch in range(epochs):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()

torch.save(...)
```
The `prune_trained_model` call calculates the sparse mask and applies it to the weights. If the mask needs to be recomputed later, call:
```
ASP.compute_sparse_masks()
```
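As a quick sanity check, the achieved sparsity can be inspected directly on the parameter tensors once the masks have been applied. The helper below is purely illustrative and is not part of the ASP API:
```
import torch

# Illustrative helper (not part of ASP): report the fraction of zeroed
# entries in each weight tensor of the model.
def report_sparsity(model):
    for name, param in model.named_parameters():
        if param.dim() >= 2:  # weight tensors, not biases
            frac_zero = (param == 0).float().mean().item()
            print(f"{name}: {frac_zero:.1%} zero")
```
With a 2:4 pattern, pruned layers should report roughly 50% zeros.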
A more thorough example can be found in `./test/toy_problem.py`.

The following approach serves as a guiding example of how to generate a pruned model that can use the Sparse Tensor Cores in the NVIDIA Ampere architecture. This approach generates a model for deployment, i.e., inference mode.

1. Given a fully trained (dense) network, prune the parameter values in a 2:4 sparse pattern (illustrated in the sketch after this list).
2. Fine-tune the pruned model using the same optimization method and hyper-parameters (learning rate, schedule, number of epochs, etc.) that were used to obtain the original dense trained model.
3. (If required) Quantize the model.
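
As a point of reference, a 2:4 sparse pattern keeps the two largest-magnitude values in every contiguous group of four weights. The standalone sketch below only illustrates the pattern itself; ASP computes and enforces such masks internally, and `mask_2to4` is a hypothetical helper, not part of the library:
```
import torch

# Hypothetical helper, for illustration only: build a 2:4 mask that keeps
# the 2 largest-magnitude entries in each contiguous group of 4 weights.
def mask_2to4(weight):
    groups = weight.reshape(-1, 4)             # view the tensor in groups of 4
    idx = groups.abs().topk(2, dim=1).indices  # the 2 largest magnitudes per group
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(1, idx, True)                # keep exactly 2 of every 4 entries
    return mask.reshape(weight.shape)

w = torch.randn(8, 8)        # dummy weight tensor (element count divisible by 4)
w_pruned = w * mask_2to4(w)  # exactly 50% of the entries are zeroed
```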

In code, below is a sketch of how to use ASP for this approach (steps 1 and 2 above):
```
import torch
from apex.contrib.sparsity import ASP

model = define_model(..., pretrained=True)  # define the model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ...                             # compares ground truth with the model prediction; use the same criterion as was used to generate the dense trained model
optimizer = ...                             # optimizes model parameters; use the same optimizer as was used to generate the dense trained model
lr_scheduler = ...                          # learning-rate scheduler; use the same schedule as was used to generate the dense trained model

ASP.prune_trained_model(model, optimizer)   # compute the 2:4 sparse masks and apply them to the trained weights

x, y = DataLoader(args)
for epoch in range(epochs):                 # fine-tune the pruned model for the same number of epochs as used to generate the dense trained model
    optimizer.zero_grad()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

torch.save(...)                             # save the pruned checkpoint together with the sparsity masks
```
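
Step 3 (quantization) is independent of ASP and depends on the deployment target. As one possible sketch, PyTorch's built-in post-training dynamic quantization could be applied to the fine-tuned model; this particular flow is an assumption for illustration, not something prescribed by ASP:
```
import torch

# One possible quantization flow (an assumption, not mandated by ASP):
# post-training dynamic quantization of the pruned, fine-tuned model.
quantized_model = torch.quantization.quantize_dynamic(
    model,              # the pruned and fine-tuned model from the sketch above
    {torch.nn.Linear},  # layer types whose weights will be quantized
    dtype=torch.qint8,  # 8-bit integer weights
)
```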