Unverified Commit bcaa8a36 authored by Casper, committed by GitHub

v0.2.0 (#330)


Co-authored-by: jinz2014 <7799920+jinz2014@users.noreply.github.com>
Co-authored-by: Jin Z <5zj@cousteau.ftpn.ornl.gov>
parent c69d3b65
@@ -6,6 +6,7 @@ from awq.modules.linear.marlin import WQLinear_Marlin
from awq.modules.linear.exllama import WQLinear_Exllama
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2

def prepare_correct_devices(next_layer, hidden_states, mask):
    hidden_states = hidden_states.to(next_layer.device)
@@ -14,6 +15,7 @@ def prepare_correct_devices(next_layer, hidden_states, mask):
    return hidden_states, mask

def prepare_cache(blocks, seqlen: int) -> int:
    for block in blocks:
        start_pos = block.attn.start_pos
@@ -21,12 +23,15 @@ def prepare_cache(blocks, seqlen: int) -> int:
        # Reset and avoid retaining state when processing context
        if seqlen > 1 and (will_cache_be_exceeded or start_pos > 0):
            block.attn.start_pos = block.attn.cache.roll_kv_n_steps(
                start_pos, n=start_pos
            )
        # Slowly roll out old tokens without performance hit if exceeded during decoding
        elif seqlen == 1 and will_cache_be_exceeded:
            block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=100)

def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
    # NOTE: from transformers 4.35.0, input_ids includes full context during decoding
    num_input_tokens = input_ids.shape[-1]
@@ -41,18 +46,22 @@ def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
    return input_ids, last_forward_num_tokens + num_new_tokens

def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor):
    mask = None
    if seqlen > 1:
        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(type_as)
    return mask

def fuse_qkv(module, q_proj, k_proj, v_proj):
    bias = (
        torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
        if q_proj.bias is not None
        else None
    )

    if isinstance(q_proj, WQLinear_GEMV):
        q_linear = WQLinear_GEMV
@@ -71,45 +80,110 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
        q_proj.in_features,
        q_proj.out_features + k_proj.out_features + v_proj.out_features,
        q_proj.bias is not None,
        next(iter(module.state_dict().values())).device,
    )

    if isinstance(q_proj, WQLinear_GEMV):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=0
        )
        qkv_layer.split_k_iters = q_proj.split_k_iters
    elif isinstance(q_proj, WQLinear_GEMM):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        )
    elif isinstance(q_proj, WQLinear_Exllama):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        )
    elif isinstance(q_proj, WQLinear_ExllamaV2):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
        )
        qkv_layer.qzeros = torch.cat(
            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        )
    elif isinstance(q_proj, WQLinear_Marlin):
        qkv_layer.qweight = torch.cat(
            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
        )
        qkv_layer.scales = torch.cat(
            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
        )
        # workspace is created in post_init
    qkv_layer.bias = bias

    for layer in [q_proj, k_proj, v_proj]:
        del (layer.qweight, layer.qzeros, layer.scales)

    return qkv_layer

def fuse_linears(linears, device, dim=1, operation=torch.cat):
    total_out_features = sum([layer.out_features for layer in linears])
    fused = WQLinear_GEMM(
        linears[0].w_bit,
        linears[0].group_size,
        linears[0].in_features,
        total_out_features,
        bias=None,
        dev=device,
    )
    fused.qweight = operation([layer.qweight for layer in linears], dim=dim)
    fused.qzeros = operation([layer.qzeros for layer in linears], dim=dim)
    fused.scales = operation([layer.scales for layer in linears], dim=dim)

    for layer in linears:
        del (layer.qweight, layer.qzeros, layer.scales, layer)

    return fused

def get_attention_shapes(
    attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim
):
    if attention_shapes is not None:
        attention_shapes = attention_shapes
    elif n_kv_heads == 0:
        attention_shapes = {
            # following fastertransformer definition
            "cache_v": (
                cache_batch_size,
                n_heads,
                max_seq_len,
                head_dim,
            ),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (
                cache_batch_size,
                n_heads,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (-1, n_heads, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, 0],
            "xk_slice": lambda xqkv: xqkv[:, :, 1],
@@ -120,26 +194,37 @@ def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_head
            "xk_reshape": (n_heads, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (n_heads, head_dim),
            "single_xv_view": (n_heads, head_dim),
        }
    else:
        attention_shapes = {
            # following fastertransformer definition
            "cache_v": (
                cache_batch_size,
                n_kv_heads,
                max_seq_len,
                head_dim,
            ),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (
                cache_batch_size,
                n_kv_heads,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, 0:n_heads],
            "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
            "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads:],
            "xq_view": (n_heads, head_dim),
            "xk_view": (n_kv_heads, head_dim),
            "xv_view": (n_kv_heads, head_dim),
            "xk_reshape": (n_kv_heads, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (n_kv_heads, head_dim),
            "single_xv_view": (n_kv_heads, head_dim),
        }

    return attention_shapes

import torch.nn as nn

def get_named_linears(module):
    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

def get_op_by_name(module, op_name):
    # get the op by its name relative to the module
    for name, m in module.named_modules():
@@ -12,10 +14,10 @@ def get_op_by_name(module, op_name):
def set_op_by_name(layer, name, new_module):
    levels = name.split(".")
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels) - 1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
@@ -43,6 +45,7 @@ def append_str_prefix(x, prefix):
    else:
        return x

def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    if modules_to_not_convert is None:
        return linear_layers
@@ -79,6 +79,7 @@ def unpack_reorder_pack(qweight, qzeros, bits):
    return qweight, qzeros

def dequantize_gemm(qweight, qzeros, scales, bits, group_size):
    # Unpack the qweight and qzeros tensors
    iweight, izeros = unpack_awq(qweight, qzeros, bits)
@@ -23,6 +23,7 @@ def auto_parallel(args):
    else:
        cuda_visible_devices = list(range(8))
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        [str(dev) for dev in cuda_visible_devices[:n_gpu]]
    )
    logging.debug("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])
    return cuda_visible_devices
@@ -8,6 +8,7 @@ def get_module_by_name_suffix(model, module_name: str):
        if name.endswith(module_name):
            return module

def simple_dispatch_model(model, device_map):
    from accelerate.hooks import add_hook_to_module, AlignDevicesHook
@@ -18,7 +19,10 @@ def simple_dispatch_model(model, device_map):
        return model

    tied_params = accelerate.utils.modeling.find_tied_parameters(model)
    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {
        "cpu",
        "disk",
    }:
        main_device = "cpu"
    else:
        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
@@ -27,10 +31,14 @@ def simple_dispatch_model(model, device_map):
        prev_hook = None
        for idx, (n, d) in enumerate(cpu_offload_group):
            m = get_module_by_name_suffix(model, n)
            _, prev_hook = accelerate.cpu_offload_with_hook(
                m, execution_device=main_device, prev_module_hook=prev_hook
            )
        # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
        if len(cpu_offload_group) > 1:
            get_module_by_name_suffix(
                model, cpu_offload_group[0][0]
            )._hf_hook.prev_module_hook = prev_hook
    for n, d in device_map.items():
        m = get_module_by_name_suffix(model, n)
@@ -43,33 +51,53 @@ def simple_dispatch_model(model, device_map):
    return model

def set_module_name(model, name, value):
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1 :]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name

    setattr(parent, child_name, value)

def clear_memory(weight=None):
    if weight is not None:
        del weight
    gc.collect()
    torch.cuda.empty_cache()

def compute_memory_used_pct(device):
    memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
    memory_pct = (
        memory_used
        / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
        * 100
    )
    return memory_pct

def get_best_device():
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        return "cuda:0"
    else:
        return "cpu"

def get_lowest_memory_device_index():
    device = None
    curr_device_memory_pct = 0
    for device_index in range(torch.cuda.device_count()):
        device_memory_pct = compute_memory_used_pct(device_index)
        if device is None or device_memory_pct < curr_device_memory_pct:
            device = device_index
            curr_device_memory_pct = device_memory_pct

    return device
# Examples
## Basic Quantization
AWQ performs zero-point quantization down to a precision of 4-bit integers.
You can also specify other bit widths, such as 3-bit, but some of these options may lack kernels
for running inference.
Notes:

- Some models, such as Falcon, are only compatible with a group size of 64.
- To use Marlin, you must set `zero_point` to `False` and `version` to `"Marlin"` (see the config sketch after the example below).
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
quant_path = 'mistral-instruct-v0.2-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(
    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
```
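For the Marlin note above, a minimal sketch of what the `quant_config` would look like, assuming the rest of the script stays the same; the group size shown here is just a common choice, not a requirement from this document.

```python
# Sketch: Marlin-compatible config per the note above
# (zero_point must be False, version must be "Marlin").
quant_config = { "zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin" }
```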
### Custom Data
This example includes functions that load either Wikitext or Dolly as calibration data.
Note that, at present, all samples longer than 512 are discarded.
```python
from datasets import load_dataset
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Define data loading methods
def load_dolly():
    data = load_dataset('databricks/databricks-dolly-15k', split="train")

    # concatenate data
    def concatenate_data(x):
        return {"text": x['instruction'] + '\n' + x['context'] + '\n' + x['response']}

    concatenated = data.map(concatenate_data)
    return [text for text in concatenated["text"]]

def load_wikitext():
    data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train")
    return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20]
# Quantize
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext())
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
```
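If you prefer to calibrate on Dolly instead of Wikitext, the same call accepts the other loader defined above; a minimal sketch:

```python
# Alternative: pass the Dolly loader as calibration data instead of Wikitext.
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_dolly())
```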
### GGUF Export
This computes AWQ scales and applies them to the model without running real quantization.
This keeps the quality of AWQ because the scales are applied to the weights, but it skips quantization
in order to keep the model compatible with other frameworks.
Step by step:
- `quantize()`: Compute AWQ scales and apply them
- `save_pretrained()`: Saves a non-quantized model in FP16
- `convert.py`: Convert the Huggingface FP16 weights to GGUF FP16 weights
- `quantize`: Run GGUF quantization to get real quantized weights, in this case 4-bit.
```python
import os
import subprocess
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'mistralai/Mistral-7B-v0.1'
quant_path = 'mistral-awq'
llama_cpp_path = '/workspace/llama.cpp'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM" }
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
# NOTE: We avoid packing weights, so you cannot use this model in AutoAWQ
# after quantizing. The saved model is FP16 but has the AWQ scales applied.
model.quantize(
    tokenizer,
    quant_config=quant_config,
    export_compatible=True
)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
# GGUF conversion
print('Converting model to GGUF...')
llama_cpp_method = "q4_K_M"
convert_cmd_path = os.path.join(llama_cpp_path, "convert.py")
quantize_cmd_path = os.path.join(llama_cpp_path, "quantize")
if not os.path.exists(llama_cpp_path):
    cmd = f"git clone https://github.com/ggerganov/llama.cpp.git {llama_cpp_path} && cd {llama_cpp_path} && make LLAMA_CUBLAS=1 LLAMA_CUDA_F16=1"
    subprocess.run([cmd], shell=True, check=True)

subprocess.run([
    f"python {convert_cmd_path} {quant_path} --outfile {quant_path}/model.gguf"
], shell=True, check=True)

subprocess.run([
    f"{quantize_cmd_path} {quant_path}/model.gguf {quant_path}/model_{llama_cpp_method}.gguf {llama_cpp_method}"
], shell=True, check=True)
```
## Basic Inference
To run inference, you often want to pass `fuse_layers=True` to get the advertised speedup in AutoAWQ.
Additionally, consider setting `max_seq_len` (default: 2048), as this is the maximum context length the model can hold.

Notes:

- You can specify `use_exllama_v2=True` to enable ExLlamaV2 kernels during inference (see the sketch after the example below).
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
```
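To enable the ExLlamaV2 kernels mentioned in the notes above, only the loading line changes; a minimal sketch:

```python
# Sketch: load with ExLlamaV2 kernels enabled (see the note above).
model = AutoAWQForCausalLM.from_quantized(
    quant_path, fuse_layers=True, use_exllama_v2=True
)
```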
### Transformers
You can also load an AWQ model with `AutoModelForCausalLM`; just make sure you have AutoAWQ installed.
Note that not all models will have fused modules when loading from transformers.
See more [documentation here](https://huggingface.co/docs/transformers/main/en/quantization#awq).
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
# NOTE: Must install from PR until merged
# pip install --upgrade git+https://github.com/younesbelkada/transformers.git@add-awq
model_id = "casperhansen/mistral-7b-instruct-v0.1-awq"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cuda:0"
)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
text = "[INST] What are the basic steps to use the Huggingface transformers library? [/INST]"
tokens = tokenizer(
    text,
    return_tensors='pt'
).input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512
)
```
### vLLM
You can also load AWQ models in [vLLM](https://github.com/vllm-project/vllm).
```python
import asyncio
from transformers import AutoTokenizer, PreTrainedTokenizer
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
model_path = "casperhansen/mixtral-instruct-awq"
# prompting
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?",
prompt_template = "[INST] {prompt} [/INST]"
# sampling params
sampling_params = SamplingParams(
repetition_penalty=1.1,
temperature=0.8,
max_tokens=512
)
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# async engine args for streaming
engine_args = AsyncEngineArgs(
    model=model_path,
    quantization="awq",
    dtype="float16",
    max_model_len=512,
    enforce_eager=True,
    disable_log_requests=True,
    disable_log_stats=True,
)
async def generate(model: AsyncLLMEngine, tokenizer: PreTrainedTokenizer):
    tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids

    outputs = model.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id=1,
        prompt_token_ids=tokens,
    )

    print("\n** Starting generation!\n")
    last_index = 0

    async for output in outputs:
        print(output.outputs[0].text[last_index:], end="", flush=True)
        last_index = len(output.outputs[0].text)

    print("\n\n** Finished generation!\n")

if __name__ == '__main__':
    model = AsyncLLMEngine.from_engine_args(engine_args)
    asyncio.run(generate(model, tokenizer))
```
### LLaVa (multimodal)
AutoAWQ also supports the LLaVa model. You simply need to load an
`AutoProcessor` to turn the prompt and image into inputs for the AWQ model.
```python
import requests
import torch
from PIL import Image
from awq import AutoAWQForCausalLM
from transformers import AutoProcessor
quant_path = "ybelkada/llava-1.5-7b-hf-awq"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, safetensors=True, device_map={"": 0})
processor = AutoProcessor.from_pretrained(quant_path)
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
# Generate output
generation_output = model.generate(
    **inputs,
    max_new_tokens=512
)
print(processor.decode(generation_output[0], skip_special_tokens=True))
```
# Auto and Base model classes in AutoAWQ
View the documentation of the main classes of AutoAWQ models below.
::: awq.models.auto.AutoAWQForCausalLM
::: awq.models.base.BaseAWQForCausalLM
# AutoAWQ examples
Please see the docs for more thorough examples. This folder contains only basic examples of
quantization, inference, and training.