Unverified Commit c6c7b065 authored by Casper, committed by GitHub

Torch only inference + any-device quantization (#319)

parent 8117845b
import os
import gc
import json
+import time
import torch
+import transformers
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from safetensors.torch import save_file
from huggingface_hub import snapshot_download
-import transformers
from transformers.modeling_utils import shard_checkpoint
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.modules.exllama import WQLinear_Exllama, exllama_post_init
-from awq.modules.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
+from awq.modules.linear.gemm import WQLinear_GEMM
+from awq.modules.linear.gemv import WQLinear_GEMV
+from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
+from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,
...@@ -35,9 +34,6 @@ from accelerate.big_modeling import (
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.modules.exllama import WQLinear_Exllama
-from awq.modules.exllamav2 import WQLinear_ExllamaV2
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.module import get_named_linears, set_op_by_name
...
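The imports above show the main refactor in this commit: the monolithic awq.modules.linear (plus the separate exllama modules) is split into per-backend submodules under awq.modules.linear. Below is a minimal sketch of how code outside this diff might pick a backend class under the new layout; the pick_linear_cls helper and its version keys are illustrative assumptions, not part of the commit.

from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.exllama import WQLinear_Exllama
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2

def pick_linear_cls(version: str):
    # Map a quant-config version string to the matching linear implementation.
    backends = {
        "gemm": WQLinear_GEMM,
        "gemv": WQLinear_GEMV,
        "exllama": WQLinear_Exllama,
        "exllamav2": WQLinear_ExllamaV2,
    }
    return backends[version.lower()]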
import torch.nn as nn
-import awq_ext
import torch.nn.functional as F
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
+from awq.modules.linear.gemm import WQLinear_GEMM
+from awq.modules.linear.gemv import WQLinear_GEMV
+try:
+    import awq_ext # with CUDA kernels
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

class QuantFusedMLP(nn.Module):
    def __init__(
...
import torch
from torch import nn
-import awq_ext
+try:
+    import awq_ext # with CUDA kernels
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

class FasterTransformerRMSNorm(nn.Module):
    def __init__(self, weight, eps=1e-6):
...
import torch
import torch.nn as nn
-from awq.utils.exllama_utils import unpack_reorder_pack
+from awq.utils.packing_utils import unpack_reorder_pack
-import exl_ext # with CUDA kernels (AutoAWQ_kernels)
+try:
+    import exl_ext # with CUDA kernels (AutoAWQ_kernels)
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
...
import torch
import torch.nn as nn
from typing import Dict
-from awq.utils.exllama_utils import unpack_reorder_pack
+from awq.utils.packing_utils import unpack_reorder_pack
-import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)
+try:
+    import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
...
import torch
import torch.nn as nn
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
        # quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
awq_linear.scales = scales.clone().half()
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
pack_num = 32 // awq_linear.w_bit
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[idx // group_size])
/ awq_linear.scales[idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)
best_device = get_best_device()
# Avoid: The operator 'aten::__lshift__.Scalar' is not currently implemented for the MPS device
if "mps" in best_device:
intweight = intweight.to("cpu")
qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=intweight.device,
)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32, device=best_device)
if "mps" in best_device:
zeros = zeros.to("cpu")
qzeros = torch.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=zeros.device,
)
for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
awq_linear.qzeros = qzeros
return awq_linear
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
if AWQ_INSTALLED:
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
else:
out = dequantize_gemm(
self.qweight,
self.qzeros,
self.scales,
self.w_bit,
self.group_size
)
out = torch.matmul(x, out)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
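This new gemm.py is the heart of the torch-only inference path: when awq_ext is not installed, forward() rebuilds fp16 weights with dequantize_gemm and falls back to a plain torch.matmul. A rough sketch of exercising that fallback on CPU, assuming the kernels are absent; the layer sizes are arbitrary and init_only just allocates empty buffers.

import torch
from awq.modules.linear.gemm import WQLinear_GEMM

# Build an (empty) 4-bit GEMM layer from a regular nn.Linear on CPU.
lin = torch.nn.Linear(4096, 4096, bias=False)
qlin = WQLinear_GEMM.from_linear(lin, w_bit=4, group_size=128, init_only=True)

x = torch.randn(1, 8, 4096, dtype=torch.float16)
with torch.no_grad():
    y = qlin(x)  # without awq_ext this runs dequantize_gemm + torch.matmul
print(y.shape)   # torch.Size([1, 8, 4096])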
import torch
import torch.nn as nn
-import awq_ext # with CUDA kernels
+try:
+    import awq_ext # with CUDA kernels
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

def make_divisible(c, divisor):
...@@ -23,159 +27,6 @@ def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    return base_width
class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
        # quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
awq_linear.scales = scales.clone().half()
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
pack_num = 32 // awq_linear.w_bit
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[idx // group_size])
/ awq_linear.scales[idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=intweight.device,
)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=zeros.device,
)
for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
awq_linear.qzeros = qzeros
return awq_linear
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
class WQLinear_GEMV(nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()
...
...@@ -6,10 +6,12 @@ import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List
from collections import defaultdict
-from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
+from awq.utils.utils import clear_memory, get_best_device
+from awq.modules.linear.gemm import WQLinear_GEMM
+from awq.modules.linear.gemv import WQLinear_GEMV
from awq.utils.module import (
    append_str_prefix,
    get_op_name,
...@@ -83,7 +85,12 @@ class AwqQuantizer:
            # Move module and inputs to correct device
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
-                self.modules[i] = self.modules[i].cuda("cuda:" + str(i % torch.cuda.device_count()))
+                if torch.cuda.is_available():
+                    best_device = "cuda:" + str(i % torch.cuda.device_count())
+                else:
+                    best_device = get_best_device()
+                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            if self.module_kwargs.get("position_ids") is not None:
...@@ -132,7 +139,7 @@ class AwqQuantizer:
    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
-            linear_layer = linear_layer.cuda().half()
+            linear_layer = linear_layer.to(get_best_device()).half()
            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data,
...@@ -274,7 +281,7 @@ class AwqQuantizer:
            if any([_ in name for _ in avoid_clipping]):
                continue
-            named_linears[name].cuda()
+            named_linears[name].to(get_best_device())
            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
            clip_list.append((name, max_val))
...@@ -343,8 +350,9 @@ class AwqQuantizer:
        inps = []
        layer_kwargs = {}

-        modules[0] = modules[0].cuda()
-        self.awq_model.move_embed(self.model, "cuda")
+        best_device = get_best_device()
+        modules[0] = modules[0].to(best_device)
+        self.awq_model.move_embed(self.model, best_device)

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
...@@ -390,7 +398,7 @@ class AwqQuantizer:
        clear_memory()

        if layer_kwargs.get("attention_mask") is not None:
-            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda")
+            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(best_device)

        return modules, layer_kwargs, inps
...
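The quantizer edits above replace every hard-coded .cuda() call with the best available device, which is what makes calibration and quantization work on CUDA, MPS, or plain CPU. A condensed sketch of the per-layer device choice from the hunk at @@ -83,7 +85,12 @@; wrapping it in a standalone helper is my own illustration, not code from the diff.

import torch
from awq.utils.utils import get_best_device

def device_for_layer(i: int) -> str:
    # Round-robin across visible GPUs when CUDA exists; otherwise fall back
    # to whatever get_best_device() reports ("mps" or "cpu").
    if torch.cuda.is_available():
        return "cuda:" + str(i % torch.cuda.device_count())
    return get_best_device()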
import torch
import torch.nn as nn
from typing import Tuple, List
+from awq.utils.utils import get_best_device
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomGelu
...@@ -14,7 +15,7 @@ allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh, GELUA
def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
    for name, max_val in clip_list:
        layer: nn.Linear = get_op_by_name(module, name)
-        layer.cuda()
+        layer.to(get_best_device())
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
...@@ -28,10 +29,11 @@ def apply_scale(module, scales_list, input_feat_dict=None):
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

-        prev_op.cuda()
+        best_device = get_best_device()
+        prev_op.to(best_device)
        for layer in layers:
-            layer.cuda()
-        scales.cuda()
+            layer.to(best_device)
+        scales.to(best_device)

        if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear):
            scale_fc_fcs(prev_op, layers, scales)
...
import torch
-from awq.modules.exllama import WQLinear_Exllama
-from awq.modules.exllamav2 import WQLinear_ExllamaV2
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
+from awq.modules.linear.gemm import WQLinear_GEMM
+from awq.modules.linear.gemv import WQLinear_GEMV
+from awq.modules.linear.exllama import WQLinear_Exllama
+from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2

def prepare_correct_devices(next_layer, hidden_states, mask):
    hidden_states = hidden_states.to(next_layer.device)
...
...@@ -78,3 +78,20 @@ def unpack_reorder_pack(qweight, qzeros, bits):
    qweight, qzeros = pack_exllama(iweight, izeros, bits)
    return qweight, qzeros
def dequantize_gemm(qweight, qzeros, scales, bits, group_size):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
# fp16 weights
scales = scales.repeat_interleave(group_size, dim=0)
izeros = izeros.repeat_interleave(group_size, dim=0)
iweight = (iweight - izeros) * scales
return iweight
\ No newline at end of file
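dequantize_gemm, added above, unpacks the int32 qweight/qzeros buffers, undoes the AWQ ordering, and expands scales and zeros per group to rebuild fp16 weights; this is what the kernel-free GEMM path multiplies against. A small shape-level sketch with placeholder packed tensors (not a real checkpoint):

import torch
from awq.utils.packing_utils import dequantize_gemm

in_features, out_features, w_bit, group_size = 512, 256, 4, 128
pack = 32 // w_bit  # 8 int4 values per int32

qweight = torch.zeros((in_features, out_features // pack), dtype=torch.int32)
qzeros = torch.zeros((in_features // group_size, out_features // pack), dtype=torch.int32)
scales = torch.ones((in_features // group_size, out_features), dtype=torch.float16)

w = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
print(w.shape)  # expected: torch.Size([512, 256]), fp16 weights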
...@@ -65,3 +65,11 @@ def compute_memory_used_pct(device): ...@@ -65,3 +65,11 @@ def compute_memory_used_pct(device):
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3) memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100 memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
return memory_pct return memory_pct
def get_best_device():
if torch.backends.mps.is_available():
return 'mps'
elif torch.cuda.is_available():
return 'cuda:0'
else:
return 'cpu'
\ No newline at end of file
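get_best_device(), added above, prefers MPS, then CUDA, then CPU, and is what the quantizer and scaling code now call instead of assuming CUDA. A short usage sketch:

import torch
from awq.utils.utils import get_best_device

device = get_best_device()        # "mps" on Apple Silicon, "cuda:0" with NVIDIA, else "cpu"
x = torch.randn(2, 2).to(device)  # tensors and modules are moved the same way
print(x.device)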
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

-quant_path = "TheBloke/zephyr-7B-beta-AWQ"
+quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
...@@ -9,12 +9,7 @@ tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
-prompt_template = """\
-<|system|>
-</s>
-<|user|>
-{prompt}</s>
-<|assistant|>"""
+prompt_template = "[INST] {prompt} [/INST]"

prompt = "You're standing on the surface of the Earth. "\
    "You walk one mile south, one mile west and one mile north. "\
...
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = """\
<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""
system = "You are a helpful assistant that answers precisely."
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(system=system, prompt=prompt),
return_tensors='pt'
).input_ids.to("mps")
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=64
)
\ No newline at end of file
import os
import sys
import torch
+import platform
from pathlib import Path
from setuptools import setup, find_packages
...@@ -8,8 +9,9 @@ os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"

AUTOAWQ_VERSION = "0.1.8"
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
+HAS_CUDA = torch.cuda.is_available()

-if not PYPI_BUILD:
+if not PYPI_BUILD and HAS_CUDA:
    try:
        CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
        AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"
...@@ -42,7 +44,6 @@ common_setup_kwargs = {
}

requirements = [
-    "autoawq-kernels",
    "torch>=2.0.1",
    "transformers>=4.35.0",
    "tokenizers>=0.12.1",
...@@ -50,6 +51,10 @@ requirements = [
    "datasets",
]

+# CUDA kernels
+if platform.system().lower() != "darwin" and HAS_CUDA:
+    requirements.append("autoawq-kernels")

setup(
    packages=find_packages(),
    install_requires=requirements,
...
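With autoawq-kernels now an optional dependency, a kernel-less install imports cleanly and relies on the AWQ_INSTALLED flags set by the guarded imports earlier in this diff. A sketch of that pattern for reference; the matmul_path helper (and catching ImportError instead of a bare except) are my own illustration:

try:
    import awq_ext  # compiled CUDA kernels from the optional autoawq-kernels package
    AWQ_INSTALLED = True
except ImportError:
    AWQ_INSTALLED = False

def matmul_path() -> str:
    # Report which branch WQLinear_GEMM.forward() would take at runtime.
    return "awq_ext.gemm_forward_cuda" if AWQ_INSTALLED else "dequantize_gemm + torch.matmul"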
import torch
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
import awq_ext
from awq.utils.packing_utils import dequantize_gemm
in_features = 4096
out_features = 1792
w_bit = 4
group_size = 128
MAX_INT32 = 0x7fffffff
MIN_INT32 = -MAX_INT32 - 1
qweight = torch.randint(
MIN_INT32,
MAX_INT32,
(in_features, out_features // (32 // w_bit)),
dtype=torch.int32,
device="cuda",
)
qzeros = torch.randint(
MIN_INT32,
MAX_INT32,
(in_features // group_size, out_features // (32 // w_bit)),
dtype=torch.int32,
device="cuda",
)
scales = torch.randn(
(in_features // group_size, out_features),
dtype=torch.float16,
device="cuda",
)
with torch.no_grad():
cuda_out = awq_ext.dequantize_weights_cuda(
qweight,
scales,
qzeros,
0,
0,
0,
False
)
torch_out = dequantize_gemm(
qweight,
qzeros,
scales,
w_bit,
group_size
)
assert(torch.allclose(cuda_out, torch_out, rtol=0.0001))
\ No newline at end of file