Commit 6dba61b8 authored by Casper

Merge branch 'main' into pr/53

parents 8e7059a7 a5e8b048
@@ -35,12 +35,13 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
+    def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        device='balanced', trust_remote_code=True, fuse_layers=True,
-                       batch_size=1) -> BaseAWQForCausalLM:
+                       batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code,
+            fuse_layers=fuse_layers, safetensors=safetensors
         )
\ No newline at end of file
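For orientation, a minimal, hypothetical usage sketch of the updated wrapper signature; the repository id is a placeholder, and `safetensors=True` assumes the weights were saved as safetensors shards.

# Hypothetical usage of the new signature (placeholder model id):
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "your-org/your-awq-model",   # placeholder; a local directory also works
    fuse_layers=True,
    safetensors=True,            # skip *.pt/*.bin files when downloading from the Hub
    batch_size=1,
)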
@@ -6,17 +6,21 @@ import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
+from typing import List, Union
 from collections import defaultdict
+from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
+from awq.utils.utils import simple_dispatch_model
 from awq.utils.calib_data import get_calib_dataset
+from transformers.modeling_utils import shard_checkpoint
 from awq.quantize.quantizer import pseudo_quantize_tensor
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
+from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
 
 class BaseAWQForCausalLM(nn.Module):
@@ -41,13 +45,17 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
                  auto_scale=True, mse_range=True, run_search=True, run_quant=True,
-                 calib_data="pileval"):
+                 calib_data: Union[str, List[str]]="pileval", split="train",
+                 text_column="text"):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
 
         if run_search:
-            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
-                                                  auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
+            self.search_result = self._awq_search(
+                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
+                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
+                split=split, text_column=text_column
+            )
 
         if run_quant:
             self._awq_quant()
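A rough sketch of the extended calibration options, assuming `model`, `tokenizer` and `quant_config` are set up as in the examples at the end of this diff; the dataset id and texts below are placeholders, not part of this commit.

# Option 1 (placeholder dataset id): calibrate from any Hugging Face dataset.
model.quantize(tokenizer, quant_config=quant_config,
               calib_data="your-org/your-dataset", split="train", text_column="text")

# Option 2: calibrate from a preprocessed list, one text sample per element.
model.quantize(tokenizer, quant_config=quant_config,
               calib_data=["example calibration text", "another calibration sample"])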
@@ -103,11 +111,14 @@ class BaseAWQForCausalLM(nn.Module):
         gc.collect()
 
     def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
-                    auto_scale=True, mse_range=True, calib_data="pileval"):
+                    auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval",
+                    split="train", text_column="text"):
         layers = self.get_model_layers(self.model)
 
         samples = get_calib_dataset(
-            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
+            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
+            split=split, text_column=text_column
+        )
         samples = torch.cat(samples, dim=0)
 
         inps = []
@@ -214,20 +225,43 @@ class BaseAWQForCausalLM(nn.Module):
         return awq_results
 
-    def save_quantized(self, save_dir):
-        def _save_files(save_dir, model_name, model):
+    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
+        def _save_files(save_dir, model_name='', search_result=None):
             class EmptyModule(nn.Module):
                 def __init__(self): super(EmptyModule, self).__init__()
                 def forward(self, x): return x
 
-            # Save model fiels without search results
+            # Save model files with empty state dict
             self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
 
-            # Remove empty module
+            # Remove empty state dict
             os.remove(f'{save_dir}/pytorch_model.bin')
 
-            # Save search results
-            torch.save(model, f'{save_dir}/{model_name}')
+            if search_result is not None:
+                torch.save(search_result, f'{save_dir}/{model_name}')
+            else:
+                # model_name has no extension, add it when saving state_dict
+                model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
+
+                # shard checkpoint into chunks (10GB default)
+                shards, index = shard_checkpoint(
+                    self.model.state_dict(),
+                    max_shard_size=shard_size,
+                    weights_name=model_name
+                )
+
+                for shard_file, shard in shards.items():
+                    if safetensors:
+                        # safetensors must be in the same memory, so we duplicate and use contiguous memory
+                        shard = {k: v.clone().contiguous() for k, v in shard.items()}
+                        save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
+                    else:
+                        torch.save(shard, os.path.join(save_dir, shard_file))
+
+                # save shard index
+                if index is not None:
+                    with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
+                        file.write(json.dumps(index, indent=4))
 
             # Save config
             with open(f'{save_dir}/quant_config.json', 'w+') as file:
@@ -237,8 +271,7 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Save model
         if self.search_result is None or self.is_quantized:
-            model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
-            _save_files(save_dir, model_name, self.model.state_dict())
+            _save_files(save_dir, '', search_result=None)
         else:
             model_name = 'awq_model_search_result.pt'
             _save_files(save_dir, model_name, self.search_result)
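A short sketch of how the new saving options could be called once quantization has run; the output directory and shard size below are illustrative, not taken from this commit.

# Save sharded safetensors instead of a single pytorch_model.bin (illustrative values).
model.save_quantized("vicuna-7b-v1.5-awq", safetensors=True, shard_size="4GB")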
@@ -259,21 +292,24 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
-                       device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
+    def from_quantized(self, model_path, model_type, model_filename='',
+                       max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
+                       trust_remote_code=True, safetensors=False, is_quantized=True,
+                       fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
             if safetensors:
-                ignore_patterns.extend(["*.pt", "*.bin"])
+                ignore_patterns.extend(["*.pt*", "*.bin*"])
             else:
-                ignore_patterns.append("*safetensors*")
+                ignore_patterns.append("*.safetensors*")
 
             model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
 
-        # TODO: Better naming, model_filename becomes a directory
-        model_filename = model_path + f'/{model_filename}'
+        if model_filename != '':
+            model_weights_path = model_path + f'/{model_filename}'
+        else:
+            model_weights_path = model_path
 
         # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
@@ -316,13 +352,14 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Load model weights
         if is_quantized:
-            model = load_checkpoint_and_dispatch(
+            load_checkpoint_in_model(
                 model,
-                model_filename,
-                device_map=device_map,
-                no_split_module_classes=[self.layer_type]
+                checkpoint=model_weights_path,
+                device_map=device_map
             )
+            model = simple_dispatch_model(model, device_map)
 
             if fuse_layers:
                 self.fuse_layers(model, quant_config)
@@ -332,7 +369,7 @@ class BaseAWQForCausalLM(nn.Module):
 
             # Load model weights
             model = AutoModelForCausalLM.from_pretrained(
-                model_filename,
+                model_weights_path,
                 device_map=device_map,
                 trust_remote_code=trust_remote_code,
                 offload_folder="offload",
......
@@ -7,7 +7,10 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: FalconForCausalLM, quant_config:dict):
         fuser = FalconFuser(model)
-        fuser.fuse_transformer()
+
+        # TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
+        if model.config.num_attention_heads == 71:
+            fuser.fuse_transformer()
 
     @staticmethod
     def get_model_layers(model: FalconForCausalLM):
......
@@ -100,6 +100,7 @@ class LlamaFuser:
             attn = QuantAttentionFused(
                 module.hidden_size,
                 module.num_heads,
+                module.num_key_value_heads,
                 qkv_layer,
                 module.o_proj,
                 next(iter(qkv_layer.state_dict().values())).device,
......
@@ -62,89 +62,61 @@ def build_alibi_bias(
     return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
 
-class QuantLlamaRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (
-            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
-        )
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings,
-            device=self.inv_freq.device,
-            dtype=torch.get_default_dtype(),
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = freqs.cos()
-        sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
-
-        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        positions: torch.Tensor,
-    ):
-        # Apply rotary embedding to the query and key before passing them
-        # to the attention op.
-        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
-        query = query.contiguous()
-        key = key.contiguous()
-        awq_inference_engine.rotary_embedding_neox(
-            positions,
-            query,
-            key,
-            self.dim,
-            self.cos_sin_cache
-        )
-        return query, key
-
 class QuantAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len,
+    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
                  use_alibi=False, attention_shapes=None):
         super().__init__()
         self.hidden_size = hidden_size
-        self.n_local_heads = num_heads
-        self.head_dim = self.hidden_size // num_heads
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
+        self.head_dim = self.hidden_size // n_heads
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
-        self.attention_shapes = attention_shapes if attention_shapes is not None else {
-            # following fastertransformer definition
-            "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
-            # 8: pack 8 fp16 in FT, if fp32 then use 4
-            "cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
-            "xqkv_view": (-1, self.n_local_heads, self.head_dim),
-            "xq_slice": lambda xqkv: xqkv[:, :, 0],
-            "xk_slice": lambda xqkv: xqkv[:, :, 1],
-            "xv_slice": lambda xqkv: xqkv[:, :, 2],
-            "xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
-            "xq_view": (self.n_local_heads, self.head_dim),
-            "xk_view": (self.n_local_heads, self.head_dim),
-            "xv_view": (self.n_local_heads, self.head_dim),
-            "single_xq_view": (self.n_local_heads, self.head_dim),
-            "single_xk_view": (self.n_local_heads, self.head_dim),
-            "single_xv_view": (self.n_local_heads, self.head_dim)
-        }
+
+        if attention_shapes is not None:
+            self.attention_shapes = attention_shapes
+
+        elif self.n_kv_heads == 0:
+            self.attention_shapes = {
+                # following fastertransformer definition
+                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
+                # 8: pack 8 fp16 in FT, if fp32 then use 4
+                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
+                "xqkv_view": (-1, self.n_heads, self.head_dim),
+                "xq_slice": lambda xqkv: xqkv[:, :, 0],
+                "xk_slice": lambda xqkv: xqkv[:, :, 1],
+                "xv_slice": lambda xqkv: xqkv[:, :, 2],
+                "xq_view": (self.n_heads, self.head_dim),
+                "xk_view": (self.n_heads, self.head_dim),
+                "xv_view": (self.n_heads, self.head_dim),
+                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
+                "single_xq_view": (self.n_heads, self.head_dim),
+                "single_xk_view": (self.n_heads, self.head_dim),
+                "single_xv_view": (self.n_heads, self.head_dim)
+            }
+        else:
+            self.attention_shapes = {
+                # following fastertransformer definition
+                "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
+                # 8: pack 8 fp16 in FT, if fp32 then use 4
+                "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
+                "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
+                "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
+                "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
+                "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
+                "xq_view": (self.n_heads, self.head_dim),
+                "xk_view": (self.n_kv_heads, self.head_dim),
+                "xv_view": (self.n_kv_heads, self.head_dim),
+                "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
+                "single_xq_view": (self.n_heads, self.head_dim),
+                "single_xk_view": (self.n_kv_heads, self.head_dim),
+                "single_xv_view": (self.n_kv_heads, self.head_dim)
+            }
 
         self.cache_v = (
             torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
@@ -155,14 +127,14 @@ class QuantAttentionFused(nn.Module):
         )
 
         if use_alibi:
-            alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
+            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
             self.alibi_slopes = alibi_slopes.float().to(dev)
             self.alibi_bias = alibi_bias.float().to(dev)
             self.rotary_dim = 0
             self.is_neox = False
         else:
             self.freqs_cis = precompute_freqs_cis(
-                hidden_size // num_heads,
+                hidden_size // n_heads,
                 max_seq_len * 2,
             ).to(dev)
             self.rotary_dim = self.head_dim
@@ -213,6 +185,11 @@ class QuantAttentionFused(nn.Module):
             xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
 
         keys = xk
         values = xv
+
+        if self.n_kv_groups != 0:
+            keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
+            values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
+
         past_key_value = (xk, xv) if use_cache else None
 
         xq = xq.transpose(1, 2)
         keys = keys.transpose(1, 2)
......
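The `n_kv_groups` handling added above implements grouped-query attention by repeating each KV head across its query group before the attention product. A standalone sketch with illustrative shapes (not taken from this commit):

import torch

# Illustrative shapes only: 2 KV heads shared by 8 query heads -> n_kv_groups = 4.
batch, seq_len, n_heads, n_kv_heads, head_dim = 1, 16, 8, 2, 64
n_kv_groups = n_heads // n_kv_heads

keys = torch.randn(batch, seq_len, n_kv_heads, head_dim)
values = torch.randn(batch, seq_len, n_kv_heads, head_dim)

# dim=2 is the head dimension, matching the fused attention module above.
keys = torch.repeat_interleave(keys, dim=2, repeats=n_kv_groups)
values = torch.repeat_interleave(values, dim=2, repeats=n_kv_groups)

assert keys.shape == (batch, seq_len, n_heads, head_dim)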
@@ -6,9 +6,13 @@ class MPTBlock(nn.Module):
     def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
         super().__init__()
         self.n_heads = n_heads
+        self.n_kv_heads = 0
         self.hidden_size = hidden_size
         self.norm_1 = norm_1
-        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=True).to(dev)
+        self.attn = QuantAttentionFused(
+            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
+            dev=dev, max_seq_len=max_seq_len, use_alibi=True
+        ).to(dev)
         self.norm_2 = norm_2
         self.ffn = mpt_mlp.to(dev)
@@ -30,16 +34,22 @@ class MPTBlock(nn.Module):
         return out, None, past_key_value
 
 class FalconDecoderLayer(nn.Module):
-    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
+    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
+                 input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
         super().__init__()
         self.n_heads = n_heads
+        self.n_kv_heads = 8
         self.hidden_size = hidden_size
         self.new_decoder_arch = new_decoder_arch
 
-        attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
+        if new_decoder_arch:
+            attention_shapes = None
+        else:
+            attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads)
 
         # TODO: Falcon has ALiBi implemented but which model uses it?
         self.attn = QuantAttentionFused(
-            hidden_size, self.n_heads, qkv_layer, o_proj,
+            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
             dev=dev, max_seq_len=max_seq_len, use_alibi=False,
             attention_shapes=attention_shapes
         ).to(dev)
@@ -52,47 +62,26 @@ class FalconDecoderLayer(nn.Module):
         self.mlp = mlp
 
-    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim, new_decoder_arch):
+    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
         batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
 
-        if new_decoder_arch:
-            kv_heads = 8
-
-            self.attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (batch_size, n_heads+(kv_heads*2), max_seq_len, head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (batch_size, n_heads+(kv_heads*2), head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (-1, n_heads+(kv_heads*2), head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, :,0],
-                "xk_slice": lambda xqkv: xqkv[:, :, :,1],
-                "xv_slice": lambda xqkv: xqkv[:, :, :,2],
-                "xk_reshape": (1, head_dim // 8, 8),
-                "xq_view": (1, head_dim),
-                "xk_view": (1, head_dim),
-                "xv_view": (1, head_dim),
-                "single_xq_view": (n_heads, head_dim),
-                "single_xk_view": (1, 8, head_dim),
-                "single_xv_view": (1, 8, head_dim)
-            }
-        else:
-            self.attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (batch_size, 1, max_seq_len, head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (n_heads+2, head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, :-2],
-                "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
-                "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
-                "xk_reshape": (1, head_dim // 8, 8),
-                "xq_view": (n_heads, head_dim),
-                "xk_view": (1, head_dim),
-                "xv_view": (1, head_dim),
-                "single_xq_view": (n_heads, head_dim),
-                "single_xk_view": (1, head_dim),
-                "single_xv_view": (1, head_dim)
-            }
+        self.attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (batch_size, 1, max_seq_len, head_dim,),
+            # 8: pack 8 fp16 in FT, if fp32 then use 4
+            "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
+            "xqkv_view": (n_heads+2, head_dim),
+            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
+            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
+            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
+            "xq_view": (n_heads, head_dim),
+            "xk_view": (1, head_dim),
+            "xv_view": (1, head_dim),
+            "xk_reshape": (1, head_dim // 8, 8),
+            "single_xq_view": (n_heads, head_dim),
+            "single_xk_view": (1, head_dim),
+            "single_xv_view": (1, head_dim)
+        }
 
         return self.attention_shapes
......
@@ -194,7 +194,13 @@ class WQLinear_GEMV(nn.Module):
     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features, )
-        out = awq_inference_engine.gemv_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.group_size)
+        inputs = x.reshape(-1, x.shape[-1])
+
+        if inputs.shape[0] > 8:
+            out = awq_inference_engine.gemmv2_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters)
+        else:
+            out = awq_inference_engine.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size)
+
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
......
 import torch
 import logging
+from typing import List, Union
 from datasets import load_dataset
 
-def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
-    if data == "pileval":
-        dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
-    else:
-        raise NotImplementedError
-    dataset = dataset.shuffle(seed=42)
+def get_calib_dataset(data: Union[str, List[str]] = "pileval",
+                      tokenizer=None, n_samples=512, block_size=512,
+                      split="train", text_column="text"):
+    if isinstance(data, str):
+        if data == "pileval":
+            dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
+        else:
+            dataset = load_dataset(data, split=split)
+
+        dataset = dataset.shuffle(seed=42)
+    elif isinstance(data, list):
+        dataset = [{text_column: text} for text in data]
+    else:
+        raise NotImplementedError(
+            "Either pass a string to a huggingface dataset or a list"
+            "that is preprocessed with one sample of text per element.")
 
     samples = []
     n_run = 0
     for data in dataset:
-        line = data["text"]
+        line = data[text_column]
         line = line.strip()
         line_encoded = tokenizer.encode(line)
         if len(line_encoded) > 512:
......
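For reference, two hypothetical calls against the reworked loader; the tokenizer model and the sample texts are placeholders, not part of this commit.

from transformers import AutoTokenizer
from awq.utils.calib_data import get_calib_dataset

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

# A Hugging Face dataset by name, choosing the split and text column explicitly.
samples = get_calib_dataset("mit-han-lab/pile-val-backup", tokenizer=tokenizer,
                            n_samples=128, split="validation", text_column="text")

# A preprocessed list with one text sample per element.
samples = get_calib_dataset(["first calibration text", "second calibration text"],
                            tokenizer=tokenizer, n_samples=2)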
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attention/ft_attention.h"
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
m.def("single_query_attention", &single_query_attention, "Attention with a single query",
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
}
\ No newline at end of file
@@ -2,3 +2,6 @@
 torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
     torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
+
+torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
+    torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
\ No newline at end of file
@@ -25,6 +25,10 @@ __pack_half2(const half x, const half y) {
   return (v1 << 16) | v0;
 }
 
+__device__ __forceinline__ int make_divisible(int c, int divisor){
+  return (c + divisor - 1) / divisor;
+}
+
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
   static constexpr uint32_t ZERO = 0x0;
@@ -412,6 +416,274 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
   }
 }
template <int G>
__global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[64];
__shared__ half A_shared[128 * (32 + 8)];
__shared__ half B_shared[64 * (32 + 8)];
// __shared__ half scaling_factors_shared[64];
// __shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 128 - 1) / 128 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 128 - 1) / 128 * j_factors1);
half A_shared_warp[32];
half B_shared_warp[16];
for (int i_0_3_init = 0; i_0_3_init < 4; ++i_0_3_init) {
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[((i_0_3_init * 16) + (j_0_4_init * 8)) + i] = 0.0;
}
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride_A = 4 * 32 * 8 / 32;
static constexpr int row_stride = 4 * 32 * 8 / 32;
const int make_divisible_multipler = 128 / G;
const int zeros_w = make_divisible(make_divisible(IC / G, 8), make_divisible_multipler) * make_divisible_multipler;
const int sf_w = zeros_w * 8;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
int ld_A_row = (blockIdx_y / j_factors1 * 128 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 128 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (IC / 8) * 8
+ (((int)threadIdx.x) / (32 / 8)) * (IC / 8)
+ (((int)blockIdx_y) % j_factors1) * 64 * (IC / 8)
+ (((int)threadIdx.x) % (32 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 4) * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* zeros_ptr = zeros
+ ((int)threadIdx.y) * zeros_w * 8
+ (((int)threadIdx.x) / (32 / 8)) * zeros_w
+ (((int)blockIdx_y) % j_factors1) * 64 * zeros_w
// this term is zero
+ (((int)threadIdx.x) % (32 / 8)) / G ;
half* scaling_factors_ptr = scaling_factors
+ ((int)threadIdx.y) * sf_w * 8
+ (((int)threadIdx.x) / (32 / 8)) * sf_w
+ (((int)blockIdx_y) % j_factors1) * (64) * sf_w
// this term is zero
+ (((int)threadIdx.x) % (32 / 8)) * 8 / G;
// Haotian: TBD, check, May 29 11:46 AM PST
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdx_z -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ (((int)threadIdx.y) / 2) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = make_divisible(IC / 32, split_k_iters); // (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
// TODO (Haotian): load scales and zero points to smem
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: Here we assume M % cta_M = 0.
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0)
{
if (ld_A_row + ax0_ax1_fused_0 * row_stride_A < M)
{
*(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = *(uint4*)(A_ptr + (ax0_ax1_fused_0 * row_stride_A * IC) + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = make_uint4(0, 0, 0, 0);
}
}
int* zeros_ptr_local = zeros_ptr + k_0_0 * 32 / G / 8;
half* scaling_factors_ptr_local = scaling_factors_ptr + k_0_0 * 32 / G;
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * (32 / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
int B_loaded_current = *(B_ptr_local + ax0_ax1_fused_0 * row_stride * (IC / 8));
int zeros_loaded = *(zeros_ptr_local + ax0_ax1_fused_0 * row_stride * zeros_w);
zeros_loaded >>= ((k_0_0 * 32 / G) % 8) * 4;
float current_zeros = (float)(zeros_loaded & 0xF);
half scaling_factors_loaded = *(scaling_factors_ptr_local + ax0_ax1_fused_0 * row_stride * sf_w);
half B_loaded_fp16[8];
#pragma unroll
for (int ic_1 = 0; ic_1 < 8; ic_1++){
float current_single_weight_fp = (float)(B_loaded_current & 0xF);
half dequantized_weight = __float2half(__half2float(scaling_factors_loaded) * (current_single_weight_fp - current_zeros));
B_loaded_current = B_loaded_current >> 4;
B_loaded_fp16[ic_1] = dequantized_weight;
}
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (32 + 8)) = *reinterpret_cast<uint4*>(B_loaded_fp16);
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) {
{
unsigned int addr;
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (k_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) {
{
unsigned int addr;
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[((((((int)threadIdx.y) >> 1) * 1280) + (ax0_0_1 * 640)) + (k_0_1 * 16))])) + ((((((int)threadIdx.x) >> 4) * 320) + ((((int)threadIdx.x) & 7) * 40)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))
);
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[3])
: "r"(addr)
);
}
}
for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3) {
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
: "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3]));
}
}
}
}
}
// Haotian: Here (May 29 11:46AM PST)
// TODO: Shang: Hoist loop invariance.
for (int ax0_0_2 = 0; ax0_0_2 < 4; ++ax0_0_2) {
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 128 + (threadIdx.y % 2) * 64 + ax0_0_2 * 16 + (local_id % 4) / 2 * 8 + ((int)threadIdx.x) / 4;
if (row_offset < M)
{
*(C_ptr + ax1_0 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax0_0_2 * 16) + (ax1_0 * 8) + local_id]);
}
}
}
}
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch::Tensor gemmv2_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int group_size,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
// for int4, need _kernel.size(1) * 8
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(0)}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
// blockIdx_x: i_factors[0] * j_factors[0]
// blockIdx_y: i_factors[1] * j_factors[1]
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks((num_out_feats + 128 - 1) / 128 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 4);
if (group_size == 128)
{
gemmv2_forward_4bit_cuda_m128n64k32<128><<<num_blocks, threads_per_block>>>(
split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (group_size == 64)
{
gemmv2_forward_4bit_cuda_m128n64k32<64><<<num_blocks, threads_per_block>>>(
split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else
{
throw std::invalid_argument("Group size temporarily not supported.");
}
return _out_feats.sum(0);
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
......
@@ -6,6 +6,7 @@ quant_path = 'vicuna-7b-v1.5-awq'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
+# NOTE: pass safetensors=True to load safetensors
 model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -13,6 +14,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 model.quantize(tokenizer, quant_config=quant_config)
 
 # Save quantized model
+# NOTE: pass safetensors=True to save quantized model weights as safetensors
 model.save_quantized(quant_path)
 tokenizer.save_pretrained(quant_path)
......
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "casperhansen/opt-125m-awq"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = """\
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
USER: {prompt}
ASSISTANT:"""
tokens = tokenizer(
prompt_template.format(prompt="How are you today?"),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
from datasets import load_dataset
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Define data loading methods
def load_dolly():
data = load_dataset('databricks/databricks-dolly-15k', split="train")
# concatenate data
def concatenate_data(x):
return {"text": x['instruction'] + '\n' + x['context'] + '\n' + x['response']}
concatenated = data.map(concatenate_data)
return [text for text in concatenated["text"]]
def load_wikitext():
data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train")
return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20]
# Quantize
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext())
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file