Commit 460d6d45 authored by Ji Lin's avatar Ji Lin
Browse files

first commit

parents
.DS_Store
data/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
*.pyc
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.pt
**/*.pt
**/*.pyc
*.json
__pycache__
# AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration
**Efficient and accurate** low-bit weight quantization (INT3/4) for LLMs, supporting **instruction-tuned** models and **multi-modal** LMs.
![overview](figures/overview.png)
The current release supports:
- AWQ search for accurate quantization.
- Pre-computed AWQ model zoo for LLMs (LLaMA, OPT, Vicuna, LLaVA; load to generate quantized weights).
- Memory-efficient 4-bit Linear in PyTorch.
- Efficient CUDA kernel implementation for fast inference (support context and decoding stage).
- Examples on 4-bit inference of an instruction-tuned model (Vicuna) and mult-modal LM (LLaVA).
## Contents
- [Install](#install)
- [AWQ Model Zoo](#awq-model-zoo)
- [Examples](#examples)
- [Usage](#usage)
- [Reference](#reference)
## Install
1. Clone this repository and navigate to AWQ folder
```
git clone https://github.com/mit-han-lab/llm-awq
cd llm-awq
```
2. Install Package
```
conda create -n awq python=3.10 -y
conda activate awq
pip install --upgrade pip # enable PEP 660 support
pip install -e .
```
3. Install kernel implementation
```
cd awq/kernels
python setup.py install
```
## AWQ Model Zoo
We provide pre-computed AWQ search results for multiple model families, including LLaMA, OPT, Vicuna, and LLaVA. To get the pre-computed AWQ search results, run:
```bash
# git lfs install # install git lfs if not already
git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
```
The detailed support list:
| Models | Sizes | INT4-g128 | INT3-g128 |
| ------ | --------------------------- | --------- | --------- |
| LLaMA | 7B/13B/30B/65B | ✅ | ✅ |
| OPT | 125m/1.3B/2.7B/6.7B/13B/30B | ✅ | ✅ |
| Vicuna | 7B/13B | ✅ | |
| LLaVA | 13B | ✅ | |
## Examples
AWQ can be easily applied to various LMs thanks to its good generalization, including instruction-tuned models and multi-modal LMs. It provides an easy-to-use tool to reduce the serving cost of LLMs.
Here we provide two examples of AWQ application: Vicuna-7B (chatbot) and LLaVA-13B (visual reasoning) under `./examples` directory. AWQ can easily reduce the GPU memory of model serving and speed up token generation. It provides accurate quantization, providing reasoning outputs. You should be able to observe **memory savings** when running the models with 4-bit weights.
Note that we perform AWQ using only textual calibration data, depsite we are running on multi-modal input. Please refer to `./examples` for details.
![overview](figures/example_vis.jpg)
## Usage
We provide several sample script to run AWQ (please refer to `./scripts`). We use OPT-6.7B as an example.
1. Perform AWQ search and save search results (we already did it for you):
```bash
python -m awq.entry --model_path /PATH/TO/OPT/opt-6.7b \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/opt-6.7b-w4-g128.pt
```
2. Evaluate the AWQ quantize model on WikiText-2 (simulated pseudo quantization)
```bash
python -m awq.entry --model_path /PATH/TO/OPT/opt-6.7b \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/opt-6.7b-w4-g128.pt \
--q_backend fake
```
3. Generate real quantized weights (INT4)
```bash
mkdir quant_cache
python -m awq.entry --model_path /PATH/TO/OPT/opt-6.7b \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/opt-6.7b-w4-g128.pt \
--q_backend real --dump_quant quant_cache/opt-6.7b-w4-g128-awq.pt
```
4. Load and evaluate the real quantized model (now you can see smaller gpu memory usage)
```bash
python -m awq.entry --model_path /PATH/TO/OPT/opt-6.7b \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_quant quant_cache/opt-6.7b-w4-g128-awq.pt
```
## Reference
If you find AWQ useful or relevant to your research, please kindly cite our paper:
```
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
```
## Related Projects
[SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://github.com/mit-han-lab/smoothquant)
[GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers](https://arxiv.org/abs/2210.17323)
[Vicuna and FastChat](https://github.com/lm-sys/FastChat#readme)
[LLaVA: Large Language and Vision Assistant](https://github.com/haotian-liu/LLaVA)
from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
help="enable model parallelism")
parser.add_argument('--auto_parallel', action='store_true',
help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true',
help="disable zero_point")
parser.add_argument('--q_backend', type=str,
default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None,
help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None,
help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true',
help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None,
help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None,
help="load the awq search results")
args = parser.parse_args()
if args.auto_parallel:
gpu_list = auto_parallel(args)
# get quantization config (apart from w_bit)
q_config = {
"zero_point": not args.no_zero_point, # by default True
"q_group_size": args.q_group_size, # whether to use group quantization
}
print("Quantization config:", q_config)
# build model and tokenizer
def build_model_and_enc(model_path):
if not os.path.exists(model_path): # look into ssd
raise FileNotFoundError(f"{model_path} not found!")
print(f"* Building model {model_path}")
# all hf model
config = AutoConfig.from_pretrained(model_path)
enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if args.load_quant: # directly load quantized weights
# no need to really load the fp16 weights... just to get the model structure
print("Loading pre-computed quantized weights...")
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
torch_dtype=torch.float16)
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config, init_only=True)
model = load_checkpoint_and_dispatch(
model, args.load_quant, device_map="balanced",
# TODO: can we remove this?
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer"]
)
else: # fp16 to quantized
kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, **kwargs)
if args.run_awq:
awq_results = run_awq(
model, enc,
w_bit=args.w_bit, q_config=q_config,
n_samples=128, seqlen=512,
)
if args.dump_awq:
torch.save(awq_results, args.dump_awq)
print("AWQ results saved at", args.dump_awq)
if args.load_awq:
print("Loading pre-computed AWQ results from", args.load_awq)
awq_results = torch.load(args.load_awq, map_location="cpu")
apply_awq(model, awq_results)
# weight quantization
if args.w_bit is not None:
if args.q_backend == "fake":
assert args.dump_quant is None, \
"Need to use real quantization to dump quantized weights"
pseudo_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config
)
elif args.q_backend == "real": # real quantization
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config
)
if args.dump_quant:
print(
f"Saving the quantized model at {args.dump_quant}...")
torch.save(model.cpu().state_dict(), args.dump_quant)
exit(0)
else:
raise NotImplementedError
return model, enc
def main():
if args.output_path is not None and os.path.exists(args.output_path):
# print(f"Results {args.output_path} already generated. Exit.")
print(f"Results {args.output_path} already generated. Overwrite.")
# exit()
if args.dump_awq and os.path.exists(args.dump_awq):
print(f"Found existing AWQ results {args.dump_awq}, exit.")
exit()
# a hack here to auto set model group
model, enc = build_model_and_enc(args.model_path)
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
if args.tasks is not None:
task_names = args.tasks.split(",")
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=task_names,
batch_size=args.batch_size,
no_cache=True,
num_fewshot=args.num_fewshot,
)
print(evaluator.make_table(results))
if args.output_path is not None:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
# otherwise cannot save
results["config"]["model"] = args.model_path
with open(args.output_path, "w") as f:
json.dump(results, f, indent=2)
if __name__ == '__main__':
main()
#pragma once
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
}
#include <torch/extension.h>
torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
#include <torch/extension.h>
#include "gemm_cuda.h"
#include "dequantize.cuh"
#include <cuda_fp16.h>
#include <c10/cuda/CUDAGuard.h>
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (128 + 8)];
__shared__ half scaling_factors_shared[128];
__shared__ half zeros_shared[128];
int j_factors1 = ((OC + 128 - 1) / 128);
half A_shared_warp[8];
half B_shared_warp[32];
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 128;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
// TODO: Haotian: blockIdx.y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx.y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx.y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 2
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+ (((int)blockIdx.y) % j_factors1) * (128 / 8)
+ (((int)threadIdx.x) % (128 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx.y) % j_factors1) * (128 / 8)
+ ((int)threadIdx.x) % (128 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx.y) % j_factors1) * (128)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C
+ blockIdx.z * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx.y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * 32 + blockIdx.z >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx.z;
__syncthreads();
// TODO: Haotian: blockIdx.y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / 128 * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
/*
if (blockIdx.z == 0 && blockIdx.y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// TODO: Shang: double check how to get 8.
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx.y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx.z == 0 && blockIdx.y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx.y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch::Tensor gemm_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
// blockIdx.x: i_factors[0] * j_factors[0]
// blockIdx.y: i_factors[1] * j_factors[1]
if (num_out_channels % 128 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 128");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks(1, (num_out_feats + 16 - 1) / 16 * j_factors1, split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
return _out_feats.sum(0);
}
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "gemm_cuda.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("gemm_forward_cuda", &gemm_forward_cuda, "our sparse conv kernel");
}
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
extra_compile_args = {
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17", "-keep"],
}
setup(
name="f16s4_gemm",
packages=find_packages(),
ext_modules=[
CUDAExtension(
name="f16s4_gemm",
sources=["pybind.cpp", "gemm_cuda_gen.cu"],
extra_compile_args=extra_compile_args,
),
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
)
\ No newline at end of file
import torch
import torch.nn as nn
from .quantizer import pseudo_quantize_tensor
import gc
__all__ = ["auto_clip_block"]
# weight quantization
@torch.no_grad()
def auto_clip_layer(w, input_feat, n_bit, q_config,
n_grid=20,
max_shrink=0.5,
n_sample_token=512):
assert w.dim() == 2
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = q_config["q_group_size"] if q_config["q_group_size"] > 0 else w.shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
w = w.reshape(w.shape[0], 1, -1, group_size)
oc_batch_size = 256 # prevent OOM
assert w.shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []
for i_b in range(w.shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1
best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
input_feat = input_feat.to(w.device)
org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group
for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
min_val = - max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = pseudo_quantize_tensor(cur_w, n_bit=n_bit, **q_config)
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
del cur_w
del cur_out
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
best_max_val[cur_best_idx] = max_val[cur_best_idx]
best_max_val_all.append(best_max_val)
best_max_val = torch.cat(best_max_val_all, dim=0)
del input_feat
del org_out
gc.collect()
torch.cuda.empty_cache()
return best_max_val.squeeze(1)
@torch.no_grad()
def auto_clip_block(module,
w_bit, q_config,
input_feat):
named_linears = {name: m for name,
m in module.named_modules() if isinstance(m, nn.Linear)}
clip_list = []
for name in named_linears:
# due to qk bmm, it is hard to clip precisely
if any([_ in name for _ in ["q_", "k_"]]):
continue
max_val = auto_clip_layer(
named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
clip_list.append((name, max_val))
return clip_list
@torch.no_grad()
def apply_clip(module, clip_list):
from ..utils.module import get_op_by_name
for name, max_val in clip_list:
layer = get_op_by_name(module, name)
max_val = max_val.to(layer.weight.device)
org_shape = layer.weight.shape
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
layer.weight.data = layer.weight.data.reshape(org_shape)
import torch
import torch.nn as nn
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from ..utils.module import get_op_by_name, get_op_name
__all__ = ["auto_scale_block", "apply_scale"]
@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
org_shape = weight.shape
if q_group_size > 0:
weight = weight.view(-1, q_group_size)
scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
scale = scale.view(org_shape)
scale = scale.mean(0)
return scale
@torch.no_grad()
def get_act_scale(x):
return x.abs().view(-1, x.shape[-1]).mean(0)
@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
if not isinstance(fcs, list):
fcs = [fcs]
scales = scales.to(ln.weight.device)
ln.weight.div_(scales)
if hasattr(ln, 'bias') and ln.bias is not None:
ln.bias.div_(scales)
for fc in fcs:
fc.weight.mul_(scales.view(1, -1))
for p in ln.parameters():
assert torch.isnan(p).sum() == 0
for fc in fcs:
for p in fc.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
assert isinstance(fc1, nn.Linear)
assert isinstance(fc2, nn.Linear)
assert fc1.out_features == fc2.in_features
scales = scales.to(fc1.weight.device)
fc1.weight.div_(scales.view(-1, 1))
if fc1.bias is not None:
fc1.bias.div_(scales.view(-1))
fc2.weight.mul_(scales.view(1, -1))
for p in fc1.parameters():
assert torch.isnan(p).sum() == 0
for p in fc2.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def auto_scale_block(module, module_kwargs,
w_bit, q_config,
input_feat):
from .quantizer import pseudo_quantize_tensor
# firstly, get the weight quantize function
if w_bit is not None:
def w_quantize_func(p): return pseudo_quantize_tensor(
p, n_bit=w_bit, **q_config,
).detach()
else:
def w_quantize_func(p): return p
if "use_cache" in module_kwargs:
module_kwargs.pop("use_cache")
# find the best scale ratio
def _search_module_scale(block, linears2scale: list, x, kwargs={}):
# w: co, ci
# x: n, ci
x = x.to(next(block.parameters()).device)
weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
w_max = get_weight_scale(
weight, q_group_size=q_config.get("q_group_size", -1))
with torch.no_grad():
org_out = block(x, **kwargs)
if isinstance(org_out, tuple):
org_out = org_out[0]
x_max = get_act_scale(x)
best_error = float('inf')
best_ratio = -1
best_scales = None
n_grid = 20
history = []
org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
for ratio in range(n_grid):
ratio = ratio * 1 / n_grid
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)
).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1))
fc.weight.data = w_quantize_func(
fc.weight.data) / (scales.view(1, -1))
out = block(x, **kwargs)
if isinstance(out, tuple):
out = out[0]
loss = (org_out - out).float().pow(2).mean().item() # float prevents overflow
history.append(loss)
is_best = loss < best_error
if is_best:
best_error = loss
best_ratio = ratio
best_scales = scales
block.load_state_dict(org_sd)
if best_ratio == -1:
print(history)
raise Exception
# print(best_ratio)
best_scales = best_scales.view(-1)
assert torch.isnan(best_scales).sum() == 0, best_scales
return best_scales.detach()
def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
# module2inspect: if given, we will check the output diff of this module instead of layers
if module2inspect is None:
assert len(layers) == 1
module2inspect = layers[0]
scales = _search_module_scale(module2inspect, layers, inp, kwargs)
# prev_op_name, [layer_name], scale
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
scales_list = [] # return the searched scales
if isinstance(module, OPTDecoderLayer):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.self_attn_layer_norm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat['self_attn.out_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat['fc1'],
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat['fc2'],
))
elif isinstance(module, LlamaDecoderLayer):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'],
module2inspect=module.mlp,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
else:
raise NotImplementedError(f"{type(module)} not supported yet!")
return scales_list
def apply_scale(module, scales_list, input_feat_dict=None):
for prev_op_name, layer_names, scales in scales_list:
prev_op = get_op_by_name(module, prev_op_name)
layers = [get_op_by_name(module, name) for name in layer_names]
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
scale_ln_fcs(prev_op, layers, scales)
else:
raise NotImplementedError(
f"prev_op {type(prev_op)} not supported yet!")
# apply the scaling to input feat if given; prepare it for clipping
if input_feat_dict is not None:
for layer_name in layer_names:
inp = input_feat_dict[layer_name]
inp.div_(scales.view(1, -1).to(inp.device))
import torch
import torch.nn as nn
import tqdm
import gc
import functools
from collections import defaultdict
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from .auto_scale import auto_scale_block, apply_scale
from .auto_clip import auto_clip_block, apply_clip
__all__ = ["run_awq"]
def get_named_linears(module):
return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
def get_blocks(model):
if isinstance(model, LlamaForCausalLM):
layers = model.model.layers
elif isinstance(model, OPTForCausalLM):
layers = model.model.decoder.layers
else:
raise NotImplementedError(type(model))
return layers
@torch.no_grad()
def run_awq(
model, enc,
w_bit, q_config,
n_samples=512, seqlen=512,
auto_scale=True, mse_range=True,
# some configs for ablation study
calib_data="pileval",
):
from ..utils.calib_data import get_calib_dataset
from ..utils.module import append_str_prefix, get_op_name
layers = get_blocks(model)
samples = get_calib_dataset(
data=calib_data, tokenizer=enc, n_samples=n_samples, block_size=seqlen)
samples = torch.cat(samples, dim=0)
inps = []
layer_kwargs = {}
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps.append(inp)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference
# patch layer 0 to catch input and kwargs
layers[0] = Catcher(layers[0])
try:
model(samples.to(next(model.parameters()).device))
except ValueError: # work with early exit
pass
layers[0] = layers[0].module # restore
inps = inps[0]
gc.collect()
torch.cuda.empty_cache()
awq_results = {
"scale": [],
"clip": [],
}
# solve layer by layer
for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
layer = layers[i]
named_linears = get_named_linears(layer)
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
x = x[0]
x = x.detach().cpu()
feat_dict[name].append(x)
input_feat = defaultdict(list)
handles = []
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
feat_dict=input_feat)))
inps = inps.to(next(layer.parameters()).device) # in case multi-gpu
# get output as next layer's input
inps = layer(inps, **layer_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
if auto_scale: # if it applies, we should also modify the input_feat with scales
scales_list = auto_scale_block(
layer, layer_kwargs,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,
)
apply_scale(layer, scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
if mse_range:
clip_list = auto_clip_block(layer,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,)
apply_clip(layer, clip_list)
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
del input_feat
gc.collect()
torch.cuda.empty_cache()
return awq_results
def apply_awq(model, awq_results):
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
import math
import torch
import torch.nn as nn
import f16s4_gemm # with CUDA kernels
class WQLinear(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
if bias:
self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
else:
self.bias = None
@classmethod
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
if init_only: # just prepare for loading sd
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
awq_linear.scales = scales.clone().half()
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
pack_num = 32 // awq_linear.w_bit
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
awq_linear.qzeros = qzeros
return awq_linear
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out = f16s4_gemm.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
import torch
import torch.nn as nn
from tqdm import tqdm
import gc
EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
zero_point=True, q_group_size=-1,
inplace=False,
get_scale_zp=False
):
org_w_shape = w.shape
if q_group_size > 0:
assert org_w_shape[-1] % q_group_size == 0
w = w.reshape(-1, q_group_size)
assert w.dim() == 2
if zero_point:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2 ** n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
else: # we actually never used this
assert min_val is None
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (n_bit - 1) - 1
min_int = - 2 ** (n_bit - 1)
scales = max_val / max_int
zeros = 0
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
if inplace:
((w.div_(scales).round_().add_(zeros)).clamp_(
min_int, max_int).sub_(zeros)).mul_(scales)
else:
w = (torch.clamp(torch.round(w / scales) +
zeros, min_int, max_int) - zeros) * scales
assert torch.isnan(w).sum() == 0
w = w.reshape(org_w_shape)
if get_scale_zp:
return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
else:
return w
@torch.no_grad()
def pseudo_quantize_model_weight(
model, w_bit, q_config,
):
from .pre_quant import get_blocks, get_named_linears
layers = get_blocks(model)
for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
named_linears = get_named_linears(layers[i])
for n, m in named_linears.items():
m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
@torch.no_grad()
def real_quantize_model_weight(
model, w_bit, q_config,
init_only=False
):
from .qmodule import WQLinear
from .pre_quant import get_blocks, get_named_linears
assert q_config["zero_point"], "We only support zero_point quantization now."
layers = get_blocks(model)
for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
layer = layers[i]
named_linears = get_named_linears(layer)
for name, module in named_linears.items():
if init_only:
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
else:
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
levels = name.split('.')
if len(levels) > 1:
mod_ = layer
for l_idx in range(len(levels)-1):
if levels[l_idx].isdigit():
mod_ = mod_[int(levels[l_idx])]
else:
mod_ = getattr(mod_, levels[l_idx])
setattr(mod_, levels[-1], q_linear)
else:
setattr(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
\ No newline at end of file
import torch
from datasets import load_dataset
def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
if data == "pileval":
dataset = load_dataset("json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train")
else:
raise NotImplementedError
dataset = dataset.shuffle(seed=42)
samples = []
n_run = 0
for data in dataset:
line = data["text"]
line = line.strip()
line_encoded = tokenizer.encode(line)
if len(line_encoded) > 512:
continue
sample = torch.tensor([line_encoded])
if sample.numel() == 0:
continue
samples.append(sample)
n_run += 1
if n_run == n_samples:
break
# now concatenate all samples and split according to block size
cat_samples = torch.cat(samples, dim=1)
n_split = cat_samples.shape[1] // block_size
print(f" * Split into {n_split} blocks")
return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)]
import transformers
import torch
from lm_eval.base import BaseLM
import fnmatch
class LMEvalAdaptor(BaseLM):
def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
super().__init__()
assert isinstance(batch_size, int)
self.model_name = model_name
self.model = model
self.model.eval()
self.tokenizer = tokenizer
# assert isinstance(self.tokenizer, (
# transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
# transformers.T5Tokenizer, transformers.T5TokenizerFast,
# )), "this tokenizer has not been checked for compatibility yet!"
self.vocab_size = self.tokenizer.vocab_size
self._batch_size = batch_size
self._max_length = max_length
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
if self._max_length != -1:
return self._max_length
if hasattr(self.model.config, 'n_ctx'):
return self.model.config.n_ctx
elif hasattr(self.model.config, 'max_position_embeddings'):
return self.model.config.max_position_embeddings
elif hasattr(self.model.config, 'n_positions'):
return self.model.config.n_positions
elif 'bloom' in self.model_name:
return 2048
elif 'llama' in self.model_name:
return 2048 # TODO: did not check this
else:
print(self.model.config)
raise NotImplementedError
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
return self._batch_size
@property
def device(self):
return "cuda"
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
if isinstance(self.model, transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
dec_inps = torch.cat(
[
torch.tensor(
self.model.generation_config.decoder_start_token_id,
)
.tile(len(inps), 1)
.to(inps),
inps,
],
dim=1,
)
kwargs = {"decoder_input_ids": dec_inps,}
else:
kwargs = {}
out = self.model(inps, **kwargs)[0]
if "opt" in self.model_name: # there are a few extra tokens in opt, which we should omit
return out[:, :, :50257]
else:
return out # [:, :, :self.tokenizer.vocab_size]
def _model_generate(self, context, max_length, eos_token_id):
return self.model.generate(
context,
max_length=max_length,
eos_token_id=eos_token_id,
do_sample=False
)
def get_op_by_name(module, op_name):
# get the op by its name relative to the module
for name, m in module.named_modules():
if name == op_name:
return m
raise ValueError(f"Cannot find op {op_name} in module {module}")
def get_op_name(module, op):
# get the name of the op relative to the module
for name, m in module.named_modules():
if m is op:
return name
raise ValueError(f"Cannot find op {op} in module {module}")
def append_str_prefix(x, prefix):
if isinstance(x, str):
return prefix + x
elif isinstance(x, tuple):
return tuple([append_str_prefix(y, prefix) for y in x])
elif isinstance(x, list):
return [append_str_prefix(y, prefix) for y in x]
else:
return x
\ No newline at end of file
import os
import torch
import gc
def auto_parallel(args):
model_size = args.model_path.split("-")[-1]
if model_size.endswith("m"):
model_gb = 1
else:
model_gb = float(model_size[:-1])
if model_gb < 20:
n_gpu = 1
elif model_gb < 50:
n_gpu = 4
else:
n_gpu = 8
args.parallel = n_gpu > 1
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if isinstance(cuda_visible_devices, str):
cuda_visible_devices = cuda_visible_devices.split(",")
else:
cuda_visible_devices = list(range(8))
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
[str(dev) for dev in cuda_visible_devices[:n_gpu]])
print("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])
return cuda_visible_devices
# AWQ Examples
Here we provide two AWQ examples, applying to:
- [Vicuna-7B](https://github.com/lm-sys/FastChat), a chatbot with instruction-tuning
- [LLaVA-13B](https://github.com/lm-sys/FastChat), a visual LM for multi-modal applications like visual reasoning.
Here are some example output from the two demos. You should able to observe memory saving when running the demos in 4-bit. Please check the notebooks for details.
![overview](../figures/example_vis.jpg)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment