Commit fcd9637c authored by gaoqiong

Merge branch 'v0.2.5_develop' into 'main'

v0.2.5

See merge request dcutoolkit/deeplearing/autoawq!2
parents 7724cca1 427f5481
import torch


class WindowedCache:
    def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device):
        """
        The window size is the same as the max_seq_len. The window will
        automatically roll once max_seq_len is exceeded.
        """
        # [batch_size, n_kv_heads, max_seq_len, head_dim]
        self.v = torch.zeros(cache_v_shape).to(device).half()
        # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
        self.k = torch.zeros(cache_k_shape).to(device).half()
        self.max_seq_len = max_seq_len

    def get_kv(self, batch_size, start_pos, seqlen, head_dim):
        """
        Gets the key-value store in correct shapes.
        """
        xv = (
            self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
        )
        xk = (
            self.k[:batch_size, :, :, : start_pos + seqlen, :]
            .transpose(2, 3)
            .contiguous()
        )
        xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()

        return xv, xk

    def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
        """
        Updates the values in the key-value store.
        """
        self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
        self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store

    def roll_kv_n_steps(self, start_pos, n=100):
        """
        Roll the cache n steps to the left and return the rebased position.
        """
        n = min(n, self.max_seq_len)
        # Roll cache to the left
        self.v = torch.roll(self.v, shifts=-n, dims=2)
        self.k = torch.roll(self.k, shifts=-n, dims=3)

        # Zero out the new part
        self.v[:, :, -n:, :] = 0
        self.k[:, :, :, -n:, :] = 0

        return start_pos - n

    def to(self, device):
        self.k = self.k.to(device)
        self.v = self.v.to(device)

    def increase_batch_size(self, to_bsz):
        """Dynamically allocate new kv when batch size changes."""
        self.v = torch.zeros(
            to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device
        )
        self.k = torch.zeros(
            to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device
        )

    def decrease_batch_size(self, to_bsz):
        """Dynamically remove part of cache if batch size changes."""
        self.v = self.v[:to_bsz, :, :, :]
        self.k = self.k[:to_bsz, :, :, :, :]
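A minimal usage sketch for WindowedCache (not part of the merge request); the batch size, head count, head dimension, and pack factor below are illustrative and follow the packed key layout noted in the constructor comments.

# Illustrative sizes only; a real model supplies its own head/dim configuration.
import torch

batch_size, n_kv_heads, head_dim = 1, 8, 128
max_seq_len, pack_factor = 2048, 8

cache_v_shape = (batch_size, n_kv_heads, max_seq_len, head_dim)
cache_k_shape = (batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor)

cache = WindowedCache(cache_v_shape, cache_k_shape, max_seq_len, device="cpu")  # "cuda" on GPU

# Write one decoding step at position 0, then read the window back.
values = torch.zeros(batch_size, n_kv_heads, 1, head_dim, dtype=torch.half)
keys = torch.zeros(batch_size, n_kv_heads, head_dim // pack_factor, 1, pack_factor, dtype=torch.half)
cache.update_kv(values, keys, batch_size, start_pos=0, seqlen=1)
xv, xk = cache.get_kv(batch_size, start_pos=0, seqlen=1, head_dim=head_dim)
# xv and xk both come back as [batch_size, seq, n_kv_heads, head_dim]

# Once the window would overflow, roll it and rebase the position.
start_pos = cache.roll_kv_n_steps(start_pos=max_seq_len, n=100)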
import torch.nn as nn
import torch.nn.functional as F

from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV

try:
    import awq_ext  # with CUDA kernels

    AWQ_INSTALLED = True
except:
    AWQ_INSTALLED = False


class QuantFusedMLP(nn.Module):
    def __init__(
        self,
        gate_proj,
        down_proj,
        up_proj,
        activation=F.silu,
    ):
        super().__init__()

        self.register_buffer("gate_proj_qweight", gate_proj.qweight)
        self.register_buffer("gate_proj_scales", gate_proj.scales)
        self.register_buffer("gate_proj_qzeros", gate_proj.qzeros)
        self.register_buffer("up_proj_qweight", up_proj.qweight)
        self.register_buffer("up_proj_scales", up_proj.scales)
        self.register_buffer("up_proj_qzeros", up_proj.qzeros)

        self.in_features = gate_proj.in_features
        self.intermediate_size = gate_proj.out_features
        self.out_features = down_proj.out_features
        self.w_bit = gate_proj.w_bit
        self.down_proj = down_proj

        if isinstance(down_proj, WQLinear_GEMV):
            self.linear = awq_ext.gemv_forward_cuda
            self.group_size = down_proj.group_size
        else:
            self.linear = awq_ext.gemm_forward_cuda
            self.group_size = 8

        self.activation = activation

    def forward(self, x, routing_weights=None):
        out_shape = x.shape[:-1] + (self.intermediate_size,)
        x = x.reshape(-1, x.shape[-1])
        gate_output = self.linear(
            x,
            self.gate_proj_qweight,
            self.gate_proj_scales,
            self.gate_proj_qzeros,
            self.group_size,
        )
        up_output = self.linear(
            x,
            self.up_proj_qweight,
            self.up_proj_scales,
            self.up_proj_qzeros,
            self.group_size,
        )

        x = self.activation(gate_output) * up_output
        x = x.reshape(out_shape)
        x = self.down_proj(x)

        if routing_weights is not None:
            x = routing_weights * x

        return x


class QuantLlamaMLP(QuantFusedMLP):
    r"""
    QuantLlamaMLP class kept for backward compatibility; in the future, users
    should always use the `QuantFusedMLP` class instead.
    """

    def __init__(self, gate_proj, down_proj, up_proj):
        super().__init__(gate_proj, down_proj, up_proj)
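A hedged construction sketch for QuantFusedMLP (not part of the merge request), assuming your AutoAWQ version exposes WQLinear_GEMM.from_linear(..., init_only=True); the hidden and intermediate sizes are illustrative, and both construction and forward() require the awq_ext CUDA kernels.

# Hedged sketch: wire up a fused MLP from freshly initialized (empty) 4-bit GEMM
# modules just to show the expected inputs. A real model would supply already
# quantized gate/up/down projections.
import torch.nn as nn
from awq.modules.linear.gemm import WQLinear_GEMM

hidden, intermediate = 4096, 11008  # illustrative Llama-7B sizes
gate = WQLinear_GEMM.from_linear(nn.Linear(hidden, intermediate, bias=False), w_bit=4, group_size=128, init_only=True)
up = WQLinear_GEMM.from_linear(nn.Linear(hidden, intermediate, bias=False), w_bit=4, group_size=128, init_only=True)
down = WQLinear_GEMM.from_linear(nn.Linear(intermediate, hidden, bias=False), w_bit=4, group_size=128, init_only=True)

fused_mlp = QuantFusedMLP(gate_proj=gate, down_proj=down, up_proj=up)
# y = fused_mlp(x)  # x: [batch, seq, hidden]; run on GPU with awq_ext installed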
import torch
from torch import nn

try:
    import awq_ext  # with CUDA kernels

    AWQ_INSTALLED = True
except:
    AWQ_INSTALLED = False


class FasterTransformerRMSNorm(nn.Module):
    def __init__(self, weight, eps=1e-6):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, x):
        assert AWQ_INSTALLED, (
            "AWQ kernels could not be loaded. "
            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
        )

        output = torch.empty_like(x)
        awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)

        return output
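A hedged sketch of swapping an existing RMSNorm for FasterTransformerRMSNorm; the Hugging Face Llama module layout (model.model.layers[i].input_layernorm) is assumed, and the fused kernel only runs with awq_ext installed and fp16 tensors on the GPU.

# Hedged sketch (assumed HF Llama layout): reuse the existing RMSNorm weight and
# epsilon, then replace the module in place. `model` is an assumed, already
# loaded Llama-style model.
for layer in model.model.layers:
    old_norm = layer.input_layernorm
    layer.input_layernorm = FasterTransformerRMSNorm(
        old_norm.weight, old_norm.variance_epsilon
    )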
from .exllama import WQLinear_Exllama, exllama_post_init
from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from .gemm import WQLinear_GEMM
from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin, marlin_post_init
from .gemv_fast import WQLinear_GEMVFast
import torch
import logging
from typing import List, Union
from datasets import load_dataset


def get_calib_dataset(
    data: Union[str, List[str], List[List[int]]] = "pileval",
    tokenizer=None,
    n_samples=512,
    block_size=512,
    split="train",
    text_column="text",
):
    if isinstance(data, str):
        if data == "pileval":
            dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
        else:
            dataset = load_dataset(data, split=split)

        dataset = dataset.shuffle(seed=42)

    elif isinstance(data, list):
        if isinstance(data[0], str):
            dataset = [{text_column: text} for text in data]
        elif isinstance(data[0][0], int):
            dataset = data
        else:
            raise NotImplementedError(
                "Either pass a string to a huggingface dataset or a list "
                "that is preprocessed with one sample of text per element"
                " or a list of list of int for tokenized words."
            )
    else:
        raise NotImplementedError(
            "Either pass a string to a huggingface dataset or a list "
            "that is preprocessed with one sample of text per element"
            " or a list of list of int for tokenized words."
        )

    samples = []
    n_run = 0
    for data in dataset:
        if isinstance(data, list):
            line_encoded = data
        else:
            line = data[text_column]
            line = line.strip()
            line_encoded = tokenizer.encode(line)
        if len(line_encoded) > 512:
            continue
        sample = torch.tensor([line_encoded])
        if sample.numel() == 0:
            continue
        samples.append(sample)
        n_run += 1
        if n_run == n_samples:
            break

    # now concatenate all samples and split according to block size
    cat_samples = torch.cat(samples, dim=1)
    n_split = cat_samples.shape[1] // block_size
    logging.debug(f" * Split into {n_split} blocks")
    return [
        cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split)
    ]
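A hedged usage sketch for get_calib_dataset; the tokenizer checkpoint is illustrative, and any Hugging Face tokenizer works.

# Hedged usage sketch: build 512-token calibration blocks from the default
# pileval validation split and stack them into one tensor of token ids.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
blocks = get_calib_dataset(
    data="pileval",      # or a list of raw strings / pre-tokenized int lists
    tokenizer=tokenizer,
    n_samples=128,
    block_size=512,
)
calib = torch.cat(blocks, dim=0)  # [n_blocks, 512] token ids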