Commit fcd9637c authored by gaoqiong

Merge branch 'v0.2.5_develop' into 'main'

v0.2.5

See merge request dcutoolkit/deeplearing/autoawq!2
parents 7724cca1 427f5481
import os
import torch
import gc
import logging
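# Heuristic GPU allocation: infer the model's parameter count from the model path
# suffix (e.g. "-7b", "-70b", "-125m"), pick 1/4/8 GPUs accordingly, and restrict
# CUDA_VISIBLE_DEVICES to that many devices.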
def auto_parallel(args):
model_size = args.model_path.split("-")[-1]
if model_size.endswith("m"):
model_gb = 1
else:
model_gb = float(model_size[:-1])
if model_gb < 20:
n_gpu = 1
elif model_gb < 50:
n_gpu = 4
else:
n_gpu = 8
args.parallel = n_gpu > 1
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if isinstance(cuda_visible_devices, str):
cuda_visible_devices = cuda_visible_devices.split(",")
else:
cuda_visible_devices = list(range(8))
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
[str(dev) for dev in cuda_visible_devices[:n_gpu]]
)
logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])
return cuda_visible_devices
import torch
from typing import List
Q_BITS = 4
STORAGE_BITS = 32
PACK_NUM = STORAGE_BITS // Q_BITS
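# Nibble orderings within each packed 32-bit word:
# - ORDINAL_PACK_ORDER stores the eight 4-bit values sequentially.
# - AWQ_PACK_ORDER is the interleaved order used by the AWQ kernels.
# - REVERSE_AWQ_PACK_ORDER is the inverse permutation that undoes the AWQ interleaving.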
ORDINAL_PACK_ORDER = [0, 1, 2, 3, 4, 5, 6, 7]
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def pack(imatrix: torch.Tensor, direction: str = "column"):
"""
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of packing, either "column" or "row"
Returns:
qmatrix (torch.Tensor): packed matrix of integers
"""
shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=imatrix.device)
imatrix = imatrix.to(torch.int8)
imatrix = torch.bitwise_and(imatrix, 0x0F) # mask to the low 4 bits to avoid sign/overflow issues
if direction == "column":
imatrix = imatrix.view(-1, imatrix.shape[1] // PACK_NUM, PACK_NUM)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
elif direction == "row":
imatrix = imatrix.view(imatrix.shape[0] // PACK_NUM, PACK_NUM, -1)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
qmatrix = qmatrix.to(torch.int32)
return qmatrix
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
"""
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
Args:
qmatrix (torch.Tensor): matrix of packed integers
direction (str): direction of unpacking, either "column" or "row"
Returns:
imatrix (torch.Tensor): matrix of integers
"""
shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=qmatrix.device)
if direction == "column":
imatrix = torch.bitwise_right_shift(
qmatrix[:, :, None], shifts[None, None, :]
).view(qmatrix.shape[0], -1)
elif direction == "row":
imatrix = torch.bitwise_right_shift(
qmatrix[:, None, :], shifts[None, :, None]
).view(-1, qmatrix.shape[-1])
imatrix = imatrix.to(torch.int8) & 0x0F # mask to the low 4 bits to avoid sign/overflow issues
return imatrix
def quantize(fmatrix, scales, zeros, group_size):
"""
Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers.
Args:
fmatrix (torch.Tensor): matrix of 16-bit floats
scales (torch.Tensor): matrix of 16-bit floats
zeros (torch.Tensor): matrix of 4-bit integers
group_size (int): group size
Returns:
imatrix (torch.Tensor): matrix of 4-bit integers
"""
zeros = zeros.to(torch.int8) & 0x0F
imatrix = torch.round(
(
fmatrix / scales.repeat_interleave(group_size, dim=0)
+ zeros.repeat_interleave(group_size, dim=0)
)
)
imatrix = imatrix.to(torch.int8) & 0x0F
return imatrix
def dequantize(imatrix, scales, zeros, group_size):
"""
Dequantizes a 4-bit integer matrix into a float matrix.
Args:
imatrix (torch.Tensor): matrix of 4-bit integers
scales (torch.Tensor): matrix of 16-bit floats
zeros (torch.Tensor): matrix of 4-bit integers
group_size (int): group size
Returns:
fmatrix (torch.Tensor): matrix of 16-bit floats
"""
zeros = zeros.to(torch.int8) & 0x0F
imatrix = imatrix.to(torch.int8) & 0x0F
fmatrix = (
imatrix - zeros.repeat_interleave(group_size, dim=0)
) * scales.repeat_interleave(group_size, dim=0)
fmatrix = fmatrix.to(torch.float16)
return fmatrix
def apply_order(
imatrix: torch.Tensor,
direction: str = "column",
order: List[int] = ORDINAL_PACK_ORDER,
):
"""
Applies the order to a 4-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of applying order, either "column" or "row"
order (List[int]): order to apply, default is ordinal packing order
Returns:
imatrix (torch.Tensor): matrix of integers
"""
if direction == "column":
imatrix = imatrix.view(-1, PACK_NUM)[:, order].view(imatrix.shape)
elif direction == "row":
imatrix = imatrix.view(PACK_NUM, -1)[order, :].view(imatrix.shape)
return imatrix
def awq_to_exllama(qweight, qzeros):
# awq uses column packing for both weights and zeros
izeros = unpack(qzeros, direction="column")
iweights = unpack(qweight, direction="column")
# Reverse the order of the iweight and izeros tensors
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
# Subtract 1 from the izeros tensor (exllama adds 1 during inference)
izeros = izeros - 1
# exllama uses row packing for weights and column packing for zeros
qzeros = pack(izeros, direction="column")
qweight = pack(iweights, direction="row")
return qweight, qzeros
import gc
import torch
import accelerate
def get_module_by_name_suffix(model, module_name: str):
for name, module in model.named_modules():
if name.endswith(module_name):
return module
def simple_dispatch_model(model, device_map):
from accelerate.hooks import add_hook_to_module, AlignDevicesHook
if "" in device_map:
d = device_map[""]
model = model.to(torch.device(d))
model.hf_device_map = device_map
return model
tied_params = accelerate.utils.modeling.find_tied_parameters(model)
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {
"cpu",
"disk",
}:
main_device = "cpu"
else:
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
prev_hook = None
for idx, (n, d) in enumerate(cpu_offload_group):
m = get_module_by_name_suffix(model, n)
_, prev_hook = accelerate.cpu_offload_with_hook(
m, execution_device=main_device, prev_module_hook=prev_hook
)
# set first cpu offload module's prev_module_hook to the last cpu offload module's hook
if len(cpu_offload_group) > 1:
get_module_by_name_suffix(
model, cpu_offload_group[0][0]
)._hf_hook.prev_module_hook = prev_hook
for n, d in device_map.items():
m = get_module_by_name_suffix(model, n)
if d != "cpu":
d = torch.device(d)
hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
add_hook_to_module(m, hook)
accelerate.utils.modeling.retie_parameters(model, tied_params)
model.hf_device_map = device_map
return model
def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, value)
def clear_memory(weight=None):
if weight is not None:
del weight
gc.collect()
torch.cuda.empty_cache()
def compute_memory_used_pct(device):
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
memory_pct = (
memory_used
/ (torch.cuda.get_device_properties(device).total_memory / (1024**3))
* 100
)
return memory_pct
def get_best_device():
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda:0"
else:
return "cpu"
def get_lowest_memory_device_index():
device = None
curr_device_memory_pct = 0
for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index)
if device is None or device_memory_pct < curr_device_memory_pct:
device = device_index
curr_device_memory_pct = device_memory_pct
return device
# Examples
## Basic Quantization
AWQ performs zero point quantization down to a precision of 4-bit integers.
You can also specify other bit widths, such as 3-bit, but some of these options may lack kernels
for running inference.
Notes:
- Some models, like Falcon, are only compatible with group size 64.
- To use Marlin, you must set `zero_point` to `False` and `version` to `"Marlin"` (a config sketch follows the example below).
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
quant_path = 'mistral-instruct-v0.2-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
```
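For the Marlin note above, a minimal config sketch (assuming Marlin-capable kernels are installed; the group size of 128 is carried over from the example and may need adjusting):

```python
# Hedged sketch for Marlin: the Marlin kernels do not use zero points,
# so zero_point must be False and version set to "Marlin".
quant_config = { "zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin" }

# The rest of the quantization flow is unchanged:
# model.quantize(tokenizer, quant_config=quant_config)
# model.save_quantized(quant_path)
```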
### Custom Data
This includes example functions that load either wikitext or dolly as calibration data.
Note that, currently, all samples longer than 512 tokens are discarded.
```python
from datasets import load_dataset
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Define data loading methods
def load_dolly():
data = load_dataset('databricks/databricks-dolly-15k', split="train")
# concatenate data
def concatenate_data(x):
return {"text": x['instruction'] + '\n' + x['context'] + '\n' + x['response']}
concatenated = data.map(concatenate_data)
return [text for text in concatenated["text"]]
def load_wikitext():
data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train")
return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20]
# Quantize
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext())
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
```
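To calibrate on dolly instead of wikitext, pass the other loader defined above; the rest of the flow is unchanged:

```python
# Alternative calibration data: the dolly loader defined above.
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_dolly())
```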
### GGUF Export
This computes the AWQ scales and applies them to the model without running real quantization.
This preserves AWQ's quality benefit, because the scales are applied to the weights, while skipping
quantization so that the model remains compatible with other frameworks.
Step by step:
- `quantize()`: Compute AWQ scales and apply them
- `save_pretrained()`: Saves a non-quantized model in FP16
- `convert.py`: Convert the Huggingface FP16 weights to GGUF FP16 weights
- `quantize`: Run GGUF quantization to get real quantized weights, in this case 4-bit.
```python
import os
import subprocess
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'mistralai/Mistral-7B-v0.1'
quant_path = 'mistral-awq'
llama_cpp_path = '/workspace/llama.cpp'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM" }
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
# NOTE: We avoid packing weights, so you cannot use this model in AutoAWQ
# after quantizing. The saved model is FP16 but has the AWQ scales applied.
model.quantize(
tokenizer,
quant_config=quant_config,
export_compatible=True
)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
# GGUF conversion
print('Converting model to GGUF...')
llama_cpp_method = "q4_K_M"
convert_cmd_path = os.path.join(llama_cpp_path, "convert.py")
quantize_cmd_path = os.path.join(llama_cpp_path, "quantize")
if not os.path.exists(llama_cpp_path):
cmd = f"git clone https://github.com/ggerganov/llama.cpp.git {llama_cpp_path} && cd {llama_cpp_path} && make LLAMA_CUBLAS=1 LLAMA_CUDA_F16=1"
subprocess.run([cmd], shell=True, check=True)
subprocess.run([
f"python {convert_cmd_path} {quant_path} --outfile {quant_path}/model.gguf"
], shell=True, check=True)
subprocess.run([
f"{quantize_cmd_path} {quant_path}/model.gguf {quant_path}/model_{llama_cpp_method}.gguf {llama_cpp_method}"
], shell=True, check=True)
```
## Basic Inference
To run inference, you often want to run with `fuse_layers=True` to get the claimed speedup in AutoAWQ.
Additionally, consider setting `max_seq_len` (default: 2048) as this will be the maximum context that the model can hold.
Notes:
- You can specify `use_exllama_v2=True` to enable ExLlamaV2 kernels during inference (a loading variant is sketched after the example below).
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
```
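Following the notes above, a hedged variant of the loading call that enables ExLlamaV2 kernels and sets an explicit context length (values are illustrative; fused layers are disabled here, mirroring the ExLlamaV2 example in the installation notes):

```python
# Sketch: load the same model with ExLlamaV2 kernels and a larger context window.
model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    fuse_layers=False,   # the ExLlamaV2 path is shown without fused layers elsewhere in these docs
    use_exllama_v2=True,
    max_seq_len=4096,    # default is 2048
)
```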
### Transformers
You can also load an AWQ model with `AutoModelForCausalLM`; just make sure you have AutoAWQ installed.
Note that not all models will have fused modules when loading from transformers.
See more [documentation here](https://huggingface.co/docs/transformers/main/en/quantization#awq).
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
# NOTE: requires a transformers release with native AWQ support (transformers>=4.35.0)
model_id = "casperhansen/mistral-7b-instruct-v0.1-awq"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="cuda:0"
)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
text = "[INST] What are the basic steps to use the Huggingface transformers library? [/INST]"
tokens = tokenizer(
text,
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
```
### vLLM
You can also load AWQ models in [vLLM](https://github.com/vllm-project/vllm).
```python
import asyncio
from transformers import AutoTokenizer, PreTrainedTokenizer
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
model_path = "casperhansen/mixtral-instruct-awq"
# prompting
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?",
prompt_template = "[INST] {prompt} [/INST]"
# sampling params
sampling_params = SamplingParams(
repetition_penalty=1.1,
temperature=0.8,
max_tokens=512
)
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# async engine args for streaming
engine_args = AsyncEngineArgs(
model=model_path,
quantization="awq",
dtype="float16",
max_model_len=512,
enforce_eager=True,
disable_log_requests=True,
disable_log_stats=True,
)
async def generate(model: AsyncLLMEngine, tokenizer: PreTrainedTokenizer):
tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids
outputs = model.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=1,
prompt_token_ids=tokens,
)
print("\n** Starting generation!\n")
last_index = 0
async for output in outputs:
print(output.outputs[0].text[last_index:], end="", flush=True)
last_index = len(output.outputs[0].text)
print("\n\n** Finished generation!\n")
if __name__ == '__main__':
model = AsyncLLMEngine.from_engine_args(engine_args)
asyncio.run(generate(model, tokenizer))
```
### LLaVa (multimodal)
AutoAWQ also supports the LLaVa model. You simply need to load an
AutoProcessor to process the prompt and image to generate inputs for the AWQ model.
```python
import requests
import torch
from PIL import Image
from awq import AutoAWQForCausalLM
from transformers import AutoProcessor
quant_path = "ybelkada/llava-1.5-7b-hf-awq"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, safetensors=True, device_map={"": 0})
processor = AutoProcessor.from_pretrained(quant_path)
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
# Generate output
generation_output = model.generate(
**inputs,
max_new_tokens=512
)
print(processor.decode(generation_output[0], skip_special_tokens=True))
```
# AutoAWQ
AutoAWQ combines ease of use with fast inference in a single package. The following documentation
shows how to quantize models and run inference.
Example inference speed (RTX 4090, Ryzen 9 7950X, 64 tokens):
- Vicuna 7B (GEMV kernel): 198.848 tokens/s
- Mistral 7B (GEMM kernel): 156.317 tokens/s
- Mistral 7B (ExLlamaV2 kernel): 188.865 tokens/s
- Mixtral 46.7B (GEMM kernel): 93 tokens/s (2x 4090)
## Installation notes
- Install: `pip install autoawq`.
- Your torch version must match the version the wheel was built with, e.g. you cannot use torch 2.0.1 with a wheel built against torch 2.2.0 (a quick check is sketched below the AMD example).
- For AMD GPUs, inference will run through ExLlamaV2 kernels without fused layers. You need to pass the following arguments to run with AMD GPUs:
```python
model = AutoAWQForCausalLM.from_quantized(
...,
fuse_layers=False,
use_exllama_v2=True
)
```
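A quick way to confirm what your environment provides before picking a wheel (a minimal sketch; the exact wheel tags are listed on the AutoAWQ release pages):

```python
import torch

# The AutoAWQ wheel you install must be built against the same torch and CUDA (or ROCm) versions.
print("torch:", torch.__version__)
print("cuda:", torch.version.cuda)   # None on ROCm builds
print("rocm:", torch.version.hip)    # None on CUDA builds
```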
## Supported models
The detailed support list:
| Models | Sizes |
| -------- | --------------------------- |
| LLaMA-2 | 7B/13B/70B |
| LLaMA | 7B/13B/30B/65B |
| Mistral | 7B |
| Vicuna | 7B/13B |
| MPT | 7B/30B |
| Falcon | 7B/40B |
| OPT | 125m/1.3B/2.7B/6.7B/13B/30B |
| Bloom    | 560m/3B/7B                  |
| GPTJ | 6.7B |
| Aquila | 7B |
| Aquila2 | 7B/34B |
| Yi | 6B/34B |
| Qwen | 1.8B/7B/14B/72B |
| BigCode | 1B/7B/15B |
| GPT NeoX | 20B |
| GPT-J | 6B |
| LLaVa | 7B/13B |
| Mixtral | 8x7B |
| Baichuan | 7B/13B |
# Auto and Base model classes in AutoAWQ
View the documentation of the main classes of AutoAWQ models below.
::: awq.models.auto.AutoAWQForCausalLM
::: awq.models.base.BaseAWQForCausalLM
# AutoAWQ examples
Please see the docs for more thorough examples. In this folder, you will only find the
very basic examples of quantization, inference, and training.
import time
import torch
import argparse
import numpy as np
import pandas as pd
from awq import AutoAWQForCausalLM
from awq.models.base import BaseAWQForCausalLM
from transformers import AutoTokenizer, GenerationConfig, LogitsProcessor, LogitsProcessorList
class TimeMeasuringLogitsProcessor(LogitsProcessor):
def __init__(self):
self.token_times = [time.time()]
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
"""The logit processor is called after the model forward."""
# CUDA ops run asynchronously, so we synchronize for accurate time measurement
torch.cuda.synchronize()
# measure time
start_time = time.time()
self.token_times.append(start_time)
return scores
def get_prefill_duration(self):
return self.token_times[1] - self.token_times[0]
def get_decode_durations(self):
token_times = self.token_times[1:]
token_durations = [token_times[i + 1] - token_times[i] for i in range(len(token_times) - 1)]
return token_durations
def warmup(model):
warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device)
torch.mm(warm_up,warm_up)
def generate_torch(model, input_ids, n_generate):
context_time = 0
generate_time = []
with torch.inference_mode():
for i in range(n_generate):
torch.cuda.synchronize()
start = time.time()
if i == 0:
# prefill context
inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device)
else:
# decode tokens
inputs = torch.as_tensor(token, device=next(model.parameters()).device)
out = model(inputs, use_cache=True)
torch.cuda.synchronize()
token = out[0][:, -1].max(1)[1].unsqueeze(1)
if i == 0:
context_time += time.time() - start
else:
generate_time.append(time.time() - start)
return context_time, generate_time
def generate_hf(model: BaseAWQForCausalLM, input_ids, n_generate):
generation_config = GenerationConfig(
min_new_tokens=n_generate,
max_new_tokens=n_generate,
use_cache=True,
forced_eos_token_id=-100,
eos_token_id=-100,
)
time_processor = TimeMeasuringLogitsProcessor()
model.generate(
input_ids,
generation_config=generation_config,
logits_processor=LogitsProcessorList([time_processor]),
)
context_time = time_processor.get_prefill_duration()
generate_time = time_processor.get_decode_durations()
return context_time, generate_time
def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors, pretrained):
print(f" -- Loading model...")
if pretrained:
model = AutoAWQForCausalLM.from_pretrained(
model_path,
safetensors=not no_safetensors,
device_map="cuda",
torch_dtype=torch.float16,
)
else:
model = AutoAWQForCausalLM.from_quantized(
model_path, quant_file, fuse_layers=True,
max_seq_len=n_generate, batch_size=batch_size,
safetensors=not no_safetensors
)
print(f" -- Warming up...")
warmup(model)
print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...")
try:
context_time, generate_time = generator(model, input_ids, n_generate)
successful_generate = True
except RuntimeError as ex:
if 'cuda out of memory' in str(ex).lower():
successful_generate = False
else:
raise RuntimeError(ex)
total_memory_used = 0
memory_pct = 100
if successful_generate:
# number of tokens in context / time for processing context * batch size
prefill_tokens_per_second = round(input_ids.shape[1] / context_time * batch_size, 2)
# 1 second / median time per token in seconds * batch size
decode_tokens_per_second = round(1 / np.median(generate_time) * batch_size, 2)
print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
for device in range(torch.cuda.device_count()):
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
total_memory_used += memory_used
memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
print(f" ** Max Memory (device: {device}): {memory_used:.2f} GB ({memory_pct:.2f}%)")
else:
prefill_tokens_per_second = 'OOM'
decode_tokens_per_second = 'OOM'
if pretrained:
version = "FP16"
else:
version = model.quant_config.version
return {
"Batch Size": batch_size,
"Prefill Length": input_ids.shape[1],
"Decode Length": n_generate,
"Prefill tokens/s": prefill_tokens_per_second,
"Decode tokens/s": decode_tokens_per_second,
"Memory (VRAM)": f"{total_memory_used:.2f} GB ({memory_pct:.2f}%)"
}, version
def main(args):
rounds = [
{"context": 32, "n_generate": 32},
{"context": 64, "n_generate": 64},
{"context": 128, "n_generate": 128},
{"context": 256, "n_generate": 256},
{"context": 512, "n_generate": 512},
{"context": 1024, "n_generate": 1024},
{"context": 2048, "n_generate": 2048},
{"context": 4096, "n_generate": 4096},
]
if args.generator == "torch":
generator = generate_torch
elif args.generator == "hf":
generator = generate_hf
else:
raise ValueError(f"Unknown generator method passed: {args.generator}")
all_stats = []
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
for settings in rounds:
input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda()
stats, model_version = run_round(
generator,
args.model_path,
args.quant_file,
settings["n_generate"],
input_ids,
args.batch_size,
args.no_safetensors,
args.pretrained
)
all_stats.append(stats)
if stats["Prefill tokens/s"] == 'OOM':
break
df = pd.DataFrame(all_stats)
print('GPU:', torch.cuda.get_device_name())
print('Model:', args.model_path)
print('Version:', model_version)
print(df.to_markdown(index=False))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="casperhansen/mistral-7b-instruct-v0.1-awq", help="path to the model")
parser.add_argument("--quant_file", type=str, default="", help="weights filename")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
parser.add_argument("--no_safetensors", default=False, action="store_true", help="Use for disabling safetensors")
parser.add_argument("--generator", type=str, default="torch", choices=["torch", "hf"], help="weights filename")
parser.add_argument("--pretrained", default=False, action="store_true", help="Measure pretrained model.")
args = parser.parse_args()
main(args)
import argparse
from lm_eval import evaluator
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from awq.evaluation import (
evaluate_perplexity,
eval_librispeech,
eval_mmlu,
eval_humaneval,
eval_kl_divergence,
)
def run_eval(
model_path, quant_file, device, tasks, task_batch_size, task_n_shot,
task_use_pretrained, pretrained_safetensors
):
"""
Post-quantization evaluation: perplexity (wikitext), MMLU, HumanEval, LibriSpeech, KL divergence, or any EleutherAI LM Evaluation Harness task.
"""
tasks = tasks.split(',')
# Load model
if len(tasks) == 1 and tasks[0] != "mmlu" and tasks[0] != "librispeech":
if task_use_pretrained:
model = AutoAWQForCausalLM.from_pretrained(model_path, safetensors=pretrained_safetensors)
else:
model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Run the selected evaluation
if len(tasks) == 1 and tasks[0] == 'wikitext':
evaluate_perplexity(model.model, tokenizer)
elif len(tasks) == 1 and tasks[0] == 'librispeech':
eval_librispeech(model_path)
elif len(tasks) == 1 and tasks[0] == 'mmlu':
eval_mmlu(model_path, task_n_shot, task_batch_size, device, task_use_pretrained)
elif len(tasks) == 1 and tasks[0] == 'humaneval':
eval_humaneval(model, tokenizer)
elif len(tasks) == 1 and tasks[0] == 'kldiv':
eval_kl_divergence(model.model, model.model, tokenizer, seqlen=1024)
else:
# Evaluate with the EleutherAI LM Evaluation Harness
results = evaluator.simple_evaluate(
model=model,
tasks=tasks,
batch_size=task_batch_size,
no_cache=True,
num_fewshot=task_n_shot,
)
print(evaluator.make_table(results))
if __name__ == '__main__':
"""
- Run perplexity of quantized model:
python examples/eval.py --model_path casperhansen/mistral-7b-instruct-v0.1-awq
- Run perplexity of an unquantized FP16 model:
python examples/eval.py --use_pretrained --model_path lmsys/vicuna-7b-v1.5
- Run MMLU of quantized model:
python examples/eval.py --model_path TheBloke/zephyr-7B-beta-AWQ --tasks mmlu --n_shot 1 --batch_size 4
"""
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='Path to hf model')
parser.add_argument('--quant_file', default='', type=str, help='Path to quantized AWQ model file')
parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
parser.add_argument("--use_pretrained", default=False, action='store_true',
help="Pass '--use_pretrained' to use a pretrained model running FP16")
parser.add_argument("--pretrained_safetensors", default=False, action='store_true',
help="Load safetensors for FP16 model")
parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
'Separate multiple tasks with a comma. See '
'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--n_shot', type=int, default=0)
args = parser.parse_args()
run_eval(
args.model_path, args.quant_file, args.device,
args.tasks, args.batch_size, args.n_shot, args.use_pretrained,
args.pretrained_safetensors
)
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "casperhansen/llama-3-8b-instruct-awq"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
chat = [
{"role": "system", "content": "You are a concise assistant that helps answer questions."},
{"role": "user", "content": prompt},
]
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
tokens = tokenizer.apply_chat_template(
chat,
return_tensors="pt"
).cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=64,
eos_token_id=terminators
)
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
quant_path = 'mistral-instruct-v0.2-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
import datasets
from awq import AutoAWQForCausalLM
from transformers import (
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType
def prepare_split(tokenizer):
data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
prompt_template = "<s>[INST] {prompt} [/INST] {output}</s>"
def format_prompt(x):
return prompt_template.format(
prompt=x["instruction"],
output=x["output"]
)
data = data.map(
lambda x: {"text": format_prompt(x)},
).select_columns(["text"])
data = data.map(lambda x: tokenizer(x["text"]), batched=True)
return data
model_path = "TheBloke/Mistral-7B-v0.1-AWQ"
# Load model
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
# Prepare data
data_train = prepare_split(tokenizer)
# Config Lora
lora_config = LoraConfig(
r=4,
lora_alpha=8,
lora_dropout=0.5,
bias="none",
task_type=TaskType.CAUSAL_LM,
inference_mode=False
)
model = get_peft_model(model.model, lora_config)
model.print_trainable_parameters()
training_arguments = TrainingArguments(
output_dir="./output",
per_device_train_batch_size=1,
optim="adamw_torch",
num_train_epochs=1,
learning_rate=1e-4,
evaluation_strategy="no",
save_strategy="epoch",
save_steps=100,
logging_steps=50,
eval_steps=None,
load_best_model_at_end=False
)
trainer = Trainer(
model=model,
train_dataset=data_train,
args=training_arguments,
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
trainer.save_model("output")
site_name: AutoAWQ
repo_name: casper-hansen/AutoAWQ
repo_url: https://github.com/casper-hansen/AutoAWQ
nav:
- index.md
- Examples: examples.md
- Reference:
- reference/index.md
markdown_extensions:
toc:
permalink: true
markdown.extensions.codehilite:
guess_lang: false
admonition: null
codehilite: null
extra: null
pymdownx.superfences:
custom_fences:
- name: mermaid
class: mermaid
format: !!python/name:pymdownx.superfences.fence_code_format ''
pymdownx.tabbed:
alternate_style: true
pymdownx.tilde: null
attr_list: null
md_in_html: null
plugins:
search: null
mkdocstrings:
handlers:
python:
paths: [.]
options:
extensions:
- griffe_typingdoc
show_root_heading: true
show_if_no_docstring: true
inherited_members: true
members_order: source
separate_signature: true
unwrap_annotated: true
filters:
- '!^_'
merge_init_into_class: true
docstring_section_style: spacy
signature_crossrefs: true
show_symbol_type_heading: true
show_symbol_type_toc: true
theme:
name: material
palette:
- media: '(prefers-color-scheme: light)'
scheme: default
primary: teal
accent: amber
toggle:
icon: material/lightbulb
name: Switch to dark mode
- media: '(prefers-color-scheme: dark)'
scheme: slate
primary: teal
accent: amber
toggle:
icon: material/lightbulb-outline
name: Switch to light mode
features:
- search.suggest
- search.highlight
- content.tabs.link
- navigation.indexes
- content.tooltips
- navigation.path
- content.code.annotate
- content.code.copy
- content.code.select
- navigation.tabs
icon:
repo: fontawesome/brands/github-alt
zstandard
transformers==4.42.3
#!/bin/bash
# Set variables
AWQ_VERSION="0.2.5"
RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ/releases/tags/v${AWQ_VERSION}"
# Create a directory to download the wheels
mkdir -p dist
cd dist
# Download all the wheel files from the GitHub release,
# excluding CUDA- and ROCm-specific builds ('%2B' is a URL-encoded '+')
curl -s $RELEASE_URL | \
jq -r ".assets[].browser_download_url" | \
grep '\.whl' | \
grep -v '%2Bcu' | \
grep -v '%2Brocm' | \
xargs -n 1 -P 4 wget
# Rename the wheels from 'linux_x86_64' to 'manylinux2014_x86_64'
for file in *linux_x86_64.whl; do
mv "$file" "$(echo $file | sed 's/linux_x86_64/manylinux2014_x86_64/')"
done
cd ..
import os
import torch
import platform
import requests
import subprocess
from pathlib import Path
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension
from typing import Optional, Union
def get_latest_kernels_version(repo):
"""
Get the latest version of the kernels from the github repo.
"""
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest")
data = response.json()
tag_name = data["tag_name"]
version = tag_name.replace("v", "")
return version
def get_kernels_whl_url(
gpu_system_version,
release_version,
python_version,
platform,
architecture,
):
"""
Get the url for the kernels wheel file.
"""
return f"https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v{release_version}/autoawq_kernels-{release_version}+{gpu_system_version}-cp{python_version}-cp{python_version}-{platform}_{architecture}.whl"
def get_sha(pytorch_root: Union[str, Path]) -> str:
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=pytorch_root).decode('ascii').strip()
except Exception:
return 'Unknown'
def get_abi():
try:
command = "echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stdout.strip()
abi = "abi" + output.split(" ")[-1]
return abi
except Exception:
return 'abiUnknown'
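# get_version_add() assembles a DCU-specific local version suffix of the form
# "git<sha7>.abi<N>[.dtk<rocm_version>].torch<x.y.z>" and rewrites the
# __dcu_version__ line in awq/__init__.py so that get_version() can read it back.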
def get_version_add(sha: Optional[str] = None) -> str:
version=''
autoawq_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = os.path.join(os.path.join(autoawq_root, "awq"), "__init__.py")
if sha != 'Unknown':
if sha is None:
sha = get_sha(autoawq_root)
version = 'git' + sha[:7]
# abi
version += "." + get_abi()
# dtk version
if os.getenv("ROCM_PATH"):
rocm_path = os.getenv('ROCM_PATH', "")
rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
with open(rocm_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines()
rocm_version=lines[0][:-2].replace(".", "")
version += ".dtk" + rocm_version
# torch version
version += ".torch" + torch.__version__[:5]
lines=[]
with open(add_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines()
lines[1] = "__dcu_version__ = '0.2.5+das1.1.{}'\n".format(version)
with open(add_version_path, encoding="utf-8",mode="w") as file:
file.writelines(lines)
file.close()
def get_version():
get_version_add()
version_file = 'awq/__init__.py'
with open(version_file, encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__dcu_version__']
AUTOAWQ_VERSION = ""
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda
if CUDA_VERSION:
CUDA_VERSION = "".join(CUDA_VERSION.split("."))[:3]
ROCM_VERSION = os.getenv("ROCM_VERSION", None) or torch.version.hip
if ROCM_VERSION:
if ROCM_VERSION.startswith("5.6"):
ROCM_VERSION = "5.6.1"
elif ROCM_VERSION.startswith("5.7"):
ROCM_VERSION = "5.7.1"
ROCM_VERSION = "".join(ROCM_VERSION.split("."))[:3]
if not PYPI_BUILD:
if CUDA_VERSION:
AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"
elif ROCM_VERSION:
#version_info = get_version()
AUTOAWQ_VERSION += get_version()#f"+rocm{ROCM_VERSION}"
else:
raise RuntimeError(
"Your system must have either Nvidia or AMD GPU to build this package."
)
common_setup_kwargs = {
"version": AUTOAWQ_VERSION,
"name": "autoawq",
"author": "Casper Hansen",
"license": "MIT",
"python_requires": ">=3.8.0",
"description": "AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
"long_description": (Path(__file__).parent / "README.md").read_text(
encoding="UTF-8"
),
"long_description_content_type": "text/markdown",
"url": "https://github.com/casper-hansen/AutoAWQ",
"keywords": ["awq", "autoawq", "quantization", "transformers"],
"platforms": ["linux", "windows"],
"classifiers": [
"Environment :: GPU :: NVIDIA CUDA :: 11.8",
"Environment :: GPU :: NVIDIA CUDA :: 12",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: C++",
],
}
requirements = [
"torch>=2.0.1",
"transformers>=4.35.0",
"tokenizers>=0.12.1",
"typing_extensions>=4.8.0",
"accelerate",
"datasets",
"zstandard",
]
try:
if ROCM_VERSION:
import exlv2_ext
else:
import awq_ext
KERNELS_INSTALLED = True
except ImportError:
KERNELS_INSTALLED = False
# kernels can be downloaded from PyPI for CUDA 12.1 only
# for everything else, we need to download the wheels from github
if not KERNELS_INSTALLED and (CUDA_VERSION or ROCM_VERSION):
if CUDA_VERSION and CUDA_VERSION.startswith("12"):
requirements.append("autoawq-kernels")
elif (CUDA_VERSION and CUDA_VERSION.startswith("11")) or ROCM_VERSION in ["561", "571"]:
gpu_system_version = (
f"cu{CUDA_VERSION}" if CUDA_VERSION else f"rocm{ROCM_VERSION}"
)
kernels_version = get_latest_kernels_version("casper-hansen/AutoAWQ_kernels")
python_version = "".join(platform.python_version_tuple()[:2])
platform_name = platform.system().lower()
architecture = platform.machine().lower()
latest_rocm_kernels_wheels = get_kernels_whl_url(
gpu_system_version,
kernels_version,
python_version,
platform_name,
architecture,
)
requirements.append(f"autoawq-kernels@{latest_rocm_kernels_wheels}")
else:
raise RuntimeError(
"Your system have a GPU with an unsupported CUDA or ROCm version. "
"Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels"
)
force_extension = os.getenv("PYPI_FORCE_TAGS", "0")
if force_extension == "1":
# NOTE: We create an empty CUDAExtension because torch helps us with
# creating the right boilerplate to enable correct targeting of
# the autoawq-kernels package
common_setup_kwargs["ext_modules"] = [
CUDAExtension(
name="test_kernel",
sources=[],
)
]
setup(
packages=find_packages(),
install_requires=requirements,
extras_require={
"eval": ["lm_eval==0.4.1", "tabulate", "protobuf", "evaluate", "scipy"],
"dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"]
},
**common_setup_kwargs,
)
import torch
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
import awq_ext
from awq.utils.packing_utils import dequantize_gemm
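# Smoke test: fill qweight/qzeros with random int32 bit patterns (i.e. random 4-bit
# nibbles) and random fp16 scales, then check that the CUDA kernel's dequantization
# matches the pure-PyTorch reference from awq.utils.packing_utils.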
in_features = 4096
out_features = 1792
w_bit = 4
group_size = 128
MAX_INT32 = 0x7fffffff
MIN_INT32 = -MAX_INT32 - 1
qweight = torch.randint(
MIN_INT32,
MAX_INT32,
(in_features, out_features // (32 // w_bit)),
dtype=torch.int32,
device="cuda",
)
qzeros = torch.randint(
MIN_INT32,
MAX_INT32,
(in_features // group_size, out_features // (32 // w_bit)),
dtype=torch.int32,
device="cuda",
)
scales = torch.randn(
(in_features // group_size, out_features),
dtype=torch.float16,
device="cuda",
)
with torch.no_grad():
cuda_out = awq_ext.dequantize_weights_cuda(
qweight,
scales,
qzeros,
0,
0,
0,
False
)
torch_out = dequantize_gemm(
qweight,
qzeros,
scales,
w_bit,
group_size
)
assert(torch.allclose(cuda_out, torch_out, rtol=0.0001))