from abc import abstractmethod
from logging import getLogger
import torch.nn as nn
from .triton_utils.mixin import TritonModuleMixin
logger = getLogger(__name__)
class FusedBaseModule(nn.Module, TritonModuleMixin):
@classmethod
@abstractmethod
def inject_to_model(cls, *args, **kwargs):
raise NotImplementedError()
class FusedBaseAttentionModule(FusedBaseModule):
@classmethod
@abstractmethod
def inject_to_model(
cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, trainable=False, **kwargs
):
raise NotImplementedError()
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
pass
class FusedBaseMLPModule(FusedBaseModule):
@classmethod
@abstractmethod
def inject_to_model(cls, model, use_triton=False, **kwargs):
raise NotImplementedError()
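

# Illustrative sketch (hypothetical class, not from the original code): the bases above only fix
# an interface. A concrete fused module overrides `inject_to_model` as a classmethod that walks
# the model and swaps matching submodules in place, while `warmup` stays a no-op unless a kernel
# needs autotuning. The subclass below exists purely to show that contract.
class _ExampleFusedMLPModule(FusedBaseMLPModule):
    @classmethod
    def inject_to_model(cls, model, use_triton=False, **kwargs):
        # A real implementation would replace matching MLP modules here; this sketch is a no-op.
        for _name, _module in model.named_modules():
            pass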
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers.models.gptj.modeling_gptj import GPTJAttention
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
from ._fused_base import FusedBaseAttentionModule
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
dim = x.shape[-1]
if seq_len is None:
seq_len = x.shape[seq_dim]
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
)
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
def duplicate_interleave(m):
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = (duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :] for t in sincos)
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
return (x * cos) + (rotate_every_two(x) * sin)
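

# Illustrative sketch (hypothetical helper, not from the original code): a quick shape check for
# the rotary helpers above, assuming the (batch, seq, heads, rotary_dim) layout that the GPT-J
# attention below feeds into `apply_rotary_pos_emb`.
def _rotary_shape_check():
    x = torch.zeros(2, 5, 4, 8)                  # (batch, seq, heads, rotary_dim)
    sincos = fixed_pos_embedding(x, seq_dim=1)   # sin and cos, each of shape (seq, rotary_dim // 2)
    out = apply_rotary_pos_emb(x, sincos)        # the rotation preserves the input shape
    assert out.shape == x.shape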
class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
def __init__(self, config):
super().__init__()
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.attn_dropout_p = config.attn_pdrop
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim
def _split_heads(self, qkv):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = qkv.size()[:-1] + (3, self.num_attention_heads, self.head_dim)
qkv = qkv.view(new_shape) # (batch, seq_length, 3, head, head_features)
query = qkv[:, :, 0]
key = qkv[:, :, 1]
value = qkv[:, :, 2]
return query, key, value
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)
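
    # Shape walkthrough: the fused qkv projection emits (batch, seq, 3 * embed_dim);
    # `_split_heads` views it as (batch, seq, 3, heads, head_dim) and slices out q/k/v,
    # each (batch, seq, heads, head_dim). After attention the tensor is
    # (batch, heads, seq, head_dim); `_merge_heads` permutes it back to
    # (batch, seq, heads, head_dim) and flattens the last two dims into (batch, seq, embed_dim).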
def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
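        # `bias` is a (1, 1, max_positions, max_positions) lower-triangular buffer; taking rows
        # [key_length - query_length : key_length] selects the mask rows for the current queries,
        # e.g. a single-token decode step (query_length == 1) gets one row that allows attention
        # to all key_length cached positions.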
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
attn_weights = attn_weights / self.scale_attn
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def forward(
self,
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query, key, value = self._split_heads(self.qkv_proj(hidden_states))
seq_len = key.shape[1]
offset = 0
if layer_past is not None:
offset = layer_past[0].shape[-2]
seq_len += offset
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]
sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
key = apply_rotary_pos_emb(key, sincos, offset=offset)
query = apply_rotary_pos_emb(query, sincos, offset=offset)
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
is_causal = layer_past is None
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
present = (key, value)
else:
present = None
# compute self-attention: V x Softmax(QK^T)
if compare_pytorch_version("v2.0.0", op="ge"):
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=None if is_causal else attention_mask,
dropout_p=self.attn_dropout_p,
is_causal=is_causal,
)
attn_weights = None
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
@classmethod
def inject_to_model(
cls,
model,
use_triton=False,
group_size=-1,
use_cuda_fp16=True,
desc_act=False,
trainable=False,
bits: int = 4,
disable_exllama=True,
disable_exllamav2=False,
**kwargs,
):
config = model.config
QuantLinear = dynamically_import_QuantLinear(
use_triton=use_triton,
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
)
for name, m in model.named_modules():
if not isinstance(m, GPTJAttention):
continue
attn = cls(config).to(device=next(m.buffers()).device)
q_proj = m.q_proj
k_proj = m.k_proj
v_proj = m.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
if QuantLinear.QUANT_TYPE == "exllama":
if desc_act:
# See fused_llama_attn.py comment
raise ValueError(
"Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True."
)
else:
g_idx = None
else:
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qlinear_args = (
q_proj.bits,
q_proj.group_size,
q_proj.infeatures,
q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
                q_proj.bias is not None,
)
qlinear_kwargs = {"trainable": trainable}
if (not desc_act or group_size == -1) and not use_triton:
qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
qlinear_kwargs["weight_dtype"] = q_proj.scales.dtype
qkv_proj = QuantLinear(*qlinear_args, **qlinear_kwargs)
qkv_proj.qweight = qweights
qkv_proj.qzeros = qzeros
qkv_proj.scales = scales
qkv_proj.g_idx = g_idx
qkv_proj.bias = bias
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
attn.qkv_proj = qkv_proj
attn.out_proj = m.out_proj
setattr(parent, child_name, attn)
del m
__all__ = ["FusedGPTJAttentionForQuantizedModel"]
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import (
LlamaAttention,
apply_rotary_pos_emb,
)
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
from ._fused_base import FusedBaseAttentionModule
class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size,
num_heads,
qkv_proj,
o_proj,
rotary_emb,
layer_idx,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.layer_idx = layer_idx
if self.head_dim * num_heads != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads})."
)
self.qkv_proj = qkv_proj
self.o_proj = o_proj
self.rotary_emb = rotary_emb
def _shape(self, tensor, seq_len, bsz):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states,
past_key_value=None,
attention_mask=None,
position_ids=None,
output_attentions=False,
use_cache=False,
**kwargs,
):
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index. Please open an issue in AutoGPTQ if you hit this."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if use_cache:
# Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
# which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
if compare_pytorch_version("v2.0.0", op="ge"):
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=attention_mask is None and q_len > 1,
)
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
@classmethod
def inject_to_model(
cls,
model,
use_triton=False,
group_size=-1,
use_cuda_fp16=True,
desc_act=False,
trainable=False,
bits: int = 4,
disable_exllama=True,
disable_exllamav2=False,
**kwargs,
):
"""
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
"""
QuantLinear = dynamically_import_QuantLinear(
use_triton=use_triton,
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
)
for name, m in model.named_modules():
if not isinstance(m, LlamaAttention):
continue
q_proj = m.q_proj
k_proj = m.k_proj
v_proj = m.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
if QuantLinear.QUANT_TYPE == "exllama":
if desc_act:
# TODO: support it. The issue lies maybe in the line:
# int groups = qzeros.size(0);
# in exllama_ext.cpp
raise ValueError(
"Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True."
)
else:
g_idx = None
else:
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qlinear_args = (
q_proj.bits,
q_proj.group_size,
q_proj.infeatures,
q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures,
                q_proj.bias is not None,
)
qlinear_kwargs = {"trainable": trainable}
if (not desc_act or group_size == -1) and not use_triton:
qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16
qlinear_kwargs["weight_dtype"] = q_proj.scales.dtype
qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs)
qkv_layer.qweight = qweights
qkv_layer.qzeros = qzeros
qkv_layer.scales = scales
qkv_layer.g_idx = g_idx
qkv_layer.bias = bias
# Introduced in Transformers 4.36
layer_idx = None
if hasattr(m, "layer_idx"):
layer_idx = m.layer_idx
attn = cls(
m.hidden_size,
m.num_heads,
qkv_layer,
m.o_proj,
m.rotary_emb,
layer_idx=layer_idx,
)
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, attn)
__all__ = ["FusedLlamaAttentionForQuantizedModel"]
import math
from logging import getLogger
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from ..utils.import_utils import TRITON_AVAILABLE
from ._fused_base import FusedBaseMLPModule
logger = getLogger(__name__)
if TRITON_AVAILABLE:
import triton
import triton.language as tl
from .triton_utils import custom_autotune
from .triton_utils.kernels import silu
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
), # 3090
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
), # 3090
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
), # 3090
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
), # 3090
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
), # 3090
],
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
)
@triton.jit
def quant_fused_matmul_248_kernel(
a_ptr,
c_ptr,
b1_ptr,
scales1_ptr,
zeros1_ptr,
g1_ptr,
b2_ptr,
scales2_ptr,
zeros2_ptr,
g2_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Computes: C = silu(A * B1) * (A * B2)
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (1, N) float16
zeros is of shape (1, N//8) int32
"""
        infeature_per_bits = 32 // bits  # number of packed values per 32-bit word
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
        # b1_ptrs/b2_ptrs are set up such that they repeat elements along the K axis (32 // bits) times
        b1_ptrs = b1_ptr + ((offs_k[:, None] // infeature_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
        b2_ptrs = b2_ptr + ((offs_k[:, None] // infeature_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
g1_ptrs = g1_ptr + offs_k
g2_ptrs = g2_ptr + offs_k
        # shifter is used to extract the `bits`-bit values packed into each 32-bit word of B
        scales1_ptrs = scales1_ptr + offs_bn[None, :]
        scales2_ptrs = scales2_ptr + offs_bn[None, :]
        zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infeature_per_bits)
        zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infeature_per_bits)
        shifter = (offs_k % infeature_per_bits) * bits
        zeros_shifter = (offs_bn % infeature_per_bits) * bits
accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g1_idx = tl.load(g1_ptrs)
g2_idx = tl.load(g2_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)
zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
zeros1 = zeros1 + 1
zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
zeros2 = zeros2 + 1
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
b2 = tl.load(b2_ptrs)
            # Unpack the packed `bits`-bit values of b1 and b2 from their 32-bit words
            b1 = (b1 >> shifter[:, None]) & maxq  # extract the `bits`-bit values
b1 = (b1 - zeros1) * scales1 # Scale and shift
accumulator1 += tl.dot(a, b1)
b2 = (b2 >> shifter[:, None]) & maxq
b2 = (b2 - zeros2) * scales2
accumulator2 += tl.dot(a, b2)
a_ptrs += BLOCK_SIZE_K
            b1_ptrs += (BLOCK_SIZE_K // infeature_per_bits) * stride_bk
            b2_ptrs += (BLOCK_SIZE_K // infeature_per_bits) * stride_bk
g1_ptrs += BLOCK_SIZE_K
g2_ptrs += BLOCK_SIZE_K
accumulator1 = silu(accumulator1)
c = accumulator1 * accumulator2
c = c.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
else:
quant_fused_matmul_248_kernel = None
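

# Illustrative sketch (hypothetical helper, not from the original code): a plain-PyTorch reference
# for what the fused Triton kernel above computes, expressed with already-dequantized weights
# (the kernel additionally unpacks the 2/4/8-bit packed weights and applies scales/zeros on the fly).
def _fused_mlp_reference(x, gate_weight, up_weight):
    # x: (M, K) float16; gate_weight / up_weight: (K, N) dequantized weight matrices
    gate = x @ gate_weight                        # A @ B1
    up = x @ up_weight                            # A @ B2
    return torch.nn.functional.silu(gate) * up    # C = silu(A @ B1) * (A @ B2)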
class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule):
def __init__(
self,
gate_proj,
down_proj,
up_proj,
):
super().__init__()
self.infeatures = gate_proj.infeatures
self.intermediate_size = gate_proj.outfeatures
self.outfeatures = down_proj.outfeatures
self.bits = gate_proj.bits
self.maxq = gate_proj.maxq
self.gate_proj = gate_proj
self.up_proj = up_proj
self.down_proj = down_proj
def forward(self, x):
return self.down_proj(self.triton_llama_mlp(x))
def triton_llama_mlp(self, x):
with torch.cuda.device(x.device):
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
M, K = x.shape
N = self.intermediate_size
c = torch.empty((M, N), device=x.device, dtype=torch.float16)
grid = lambda META: ( # noqa: E731
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
quant_fused_matmul_248_kernel[grid](
x,
c,
self.gate_proj.qweight,
self.gate_proj.scales,
self.gate_proj.qzeros,
self.gate_proj.g_idx,
self.up_proj.qweight,
self.up_proj.scales,
self.up_proj.qzeros,
self.up_proj.g_idx,
M,
N,
K,
self.bits,
self.maxq,
x.stride(0),
x.stride(1),
self.gate_proj.qweight.stride(0),
self.gate_proj.qweight.stride(1),
c.stride(0),
c.stride(1),
self.gate_proj.scales.stride(0),
self.gate_proj.qzeros.stride(0),
)
c = c.reshape(out_shape)
return c
@classmethod
def inject_to_model(cls, model, use_triton=False, **kwargs):
if not use_triton:
logger.warning(
f"Skipping module injection for {cls.__name__} as currently not supported with use_triton=False."
)
return
elif not TRITON_AVAILABLE:
logger.warning(
f"Skipping module injection for {cls.__name__} as Triton is not available. Please check your installation."
)
return
for name, m in model.named_modules():
if not isinstance(m, LlamaMLP):
continue
mlp = cls(m.gate_proj, m.down_proj, m.up_proj)
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, mlp)
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
from tqdm import tqdm
kn_values = {}
for _, m in model.named_modules():
if not isinstance(m, cls):
continue
k = m.infeatures
n = m.intermediate_size
if (k, n) not in kn_values:
kn_values[(k, n)] = m
logger.info(f"Found {len(kn_values)} unique fused mlp KN values.")
logger.info("Warming up autotune cache ...")
with torch.no_grad():
            for exp in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
                m = 2**exp
                for (k, n), module in kn_values.items():
                    a = torch.randn(m, k, dtype=torch.float16, device=model.device)
                    module.triton_llama_mlp(a)
del kn_values
__all__ = ["FusedLlamaMLPForQuantizedModel"]