Unverified Commit 9b76cab1 authored by ZiWei Yuan, committed by GitHub

Merge pull request #898 from kvcache-ai/develop-0.2.3post2

Release 0.2.3post2
parents dfe09b05 b5ef7c26
@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -169,7 +170,7 @@ def local_chat(
        assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
        "please change max_seq_len in ~/.ktransformers/config.yaml"
-   if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
        generated = prefill_and_generate(
            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
            use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
......
@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
try:
    from flash_attn import flash_attn_func
except:
    pass
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
        value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
-       value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
-       attn_output = flash_attn_func(
-           query_states,
-           key_states,
-           value_states_padded,
-           softmax_scale=self.softmax_scale,
-           causal=True,
        # for bsz = 1
        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
        max_input_len = q_len
        context_attention_fwd(
            q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
            k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
            v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
            o=attn_output,
            b_start_loc=b_start_loc,
            b_seq_len=b_seq_len,
            max_input_len=max_input_len,
            is_causal=True
        )
        if self.q_head_dim != self.v_head_dim:
-           attn_output = attn_output[:, :, :, : self.v_head_dim]
            attn_output = attn_output[:, :, : self.v_head_dim]
        attn_output = attn_output.reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
@@ -589,8 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-       if os.name == 'nt' or get_compute_capability()<8:
        if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
-           print("for Windows or GPU before ampere, use forward_windows")
            return self.forward_windows(
                hidden_states,
                attention_mask,
......
@@ -17,7 +17,10 @@ import logging
logger = logging.getLogger("dynamic_attention")
sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
try:
    from flash_attn import flash_attn_func, flash_attn_with_kvcache
except:
    print("flash_attn not found")
import math
......
@@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config
from typing import Dict, Tuple, Optional, Union
import numpy as np
#class KLinearBase(BaseInjectedModule, ABC):
class KLinearBase(ABC):
@@ -176,16 +178,182 @@ class KLinearTorch(KLinearBase):
        if self.has_bias:
            self.bias = None
class KLinearQ8(KLinearBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.compute_dtype = torch.float32
self.weight = None
self.weight_scale = None
self.weight_zero_point = None
self.bias = None
self.loaded = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
out_device = x.device
x = x.to(device=self.device, dtype=self.compute_dtype)
        # Emulate the original full-precision behavior by doing the matmul against the weight
        # Dequantize the weight for the matrix multiplication
weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
out = x @ weight_dequant.T
if self.has_bias:
out = out + self.bias
return out.to(dtype=orig_dtype, device=out_device)
def _dequantize_weight(self, q_matrix, scales, bits=8):
"""
Dequantize a low-precision matrix back to floating-point
Args:
q_matrix (torch.Tensor): Quantized int matrix
scales (torch.Tensor): Scale factors for each column
bits (int): Quantization bits used (8 or 4)
Returns:
torch.Tensor: Dequantized floating-point matrix
"""
# Ensure inputs are torch tensors
if not isinstance(q_matrix, torch.Tensor):
q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
if not isinstance(scales, torch.Tensor):
scales = torch.tensor(scales, dtype=torch.float32)
# Convert to correct dtype if needed
if q_matrix.dtype != torch.int8:
q_matrix = q_matrix.to(torch.int8)
if scales.dtype != torch.float32:
scales = scales.to(torch.float32)
# For Q4, ensure the values stay within 4-bit range
if bits == 4:
q_matrix = torch.clamp(q_matrix, -7, 7)
rows, cols = q_matrix.shape
dequant_matrix = q_matrix.to(torch.float32)
scales_broadcast = scales.view(1, cols)
        # Apply the per-column scales to the whole matrix at once via broadcasting
dequant_matrix = dequant_matrix * scales_broadcast
return dequant_matrix
def _quantize_weight(self, matrix, bits=8):
"""
Quantize a floating-point matrix to lower precision (Q8 or Q4)
Args:
matrix (torch.Tensor): Input matrix in floating-point format
bits (int): Quantization bits, either 8 or 4
Returns:
tuple: (quantized int matrix, scale factors for each column)
"""
if not isinstance(matrix, torch.Tensor):
matrix = torch.tensor(matrix, dtype=torch.float32)
# Convert to float32 if needed
if matrix.dtype != torch.float32:
matrix = matrix.to(torch.float32)
# Get matrix shape
rows, cols = matrix.shape
# Determine quantization parameters based on bits
if bits == 8:
max_int = 127
qtype = torch.int8
elif bits == 4:
max_int = 7
qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range, wait for native support
else:
raise ValueError("Quantization bits must be either 8 or 4")
scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
# Calculate max absolute value for each column
max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)
# Handle zero columns (avoid division by zero)
zero_cols = max_abs_vals == 0
max_abs_vals[zero_cols] = 1.0
# Calculate scale factors for all columns at once
scales = max_abs_vals / max_int
# Prepare the scales for broadcasting [1, cols]
scales_broadcast = scales.view(1, cols)
# Apply quantization to the entire matrix at once
q_matrix = torch.round(matrix / scales_broadcast).to(qtype)
# For Q4, clamp values to ensure they stay within 4-bit range
if bits == 4:
q_matrix = torch.clamp(q_matrix, -max_int, max_int)
return q_matrix, scales
def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
if self.loaded: return
if device is None: device = self.device
if w is None: w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
try:
weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
except:
weight = w.to(dtype=self.compute_dtype)
self.has_bias = False
elif isinstance(w, tuple):
try:
weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
except:
weight = w[0].to(dtype=self.compute_dtype)
self.bias = w[1].to(dtype=self.compute_dtype).to(device)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
self.weight = self.weight.to(device)
self.weight_scale = self.weight_scale.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
self.loaded = True
def unload(self):
self.weight = None
self.weight_scale = None
self.weight_zero_point = None
self._orig_weight = None
if self.has_bias:
self.bias = None
self.loaded = False
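Note: `_quantize_weight` / `_dequantize_weight` above implement plain symmetric per-column quantization (scale = max|column| / 127, no zero point), and `forward` simply dequantizes before the matmul. A minimal standalone sketch of the same round trip, for reference only (the helper names below are illustrative, not part of the module):

import torch

def quantize_per_column(w: torch.Tensor):
    # symmetric per-column int8 quantization, mirroring KLinearQ8._quantize_weight
    max_abs = w.abs().amax(dim=0)
    max_abs = torch.where(max_abs == 0, torch.ones_like(max_abs), max_abs)  # avoid division by zero
    scales = max_abs / 127.0
    q = torch.round(w / scales.view(1, -1)).to(torch.int8)
    return q, scales

def dequantize_per_column(q: torch.Tensor, scales: torch.Tensor):
    # mirror of KLinearQ8._dequantize_weight: rescale each column back to float32
    return q.to(torch.float32) * scales.view(1, -1)

w = torch.randn(128, 64)
q, s = quantize_per_column(w)
w_hat = dequantize_per_column(q, s)
print("max abs round-trip error:", (w - w_hat).abs().max().item())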
class KLinearFP8(KLinearBase):
    # this kernel requires special handling for weight
    # Please load the weight file downloaded from KVCache.AI
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
    sort_indices: torch.Tensor
    has_bias: bool
    weight: torch.Tensor
    scale_w: torch.Tensor
    bias: torch.Tensor
    def __init__(
        self,
@@ -468,6 +636,7 @@ LINEAR_MAP = {
    "KLinearTorch": KLinearTorch,
    "KLinearCPUInfer": KLinearCPUInfer,
    "KLinearFP8": KLinearFP8,
    "KLinearQ8": KLinearQ8,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
......
@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
    DeepseekV2DecoderLayer,
    DeepseekV2MoE,
)
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
@@ -649,8 +650,8 @@ class KDeepseekV2Model(BaseInjectedModule):
        if per_layer_prefill_flag:
            causal_mask = None
        else:
-           if os.name == 'nt' or get_compute_capability()<8:
            if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
-               print("for Windows or GPU before ampere, use forward_windows")
                # print("for Windows or GPU before ampere, use forward_windows")
                # only use mask in forward windows or can't flash attn
                causal_mask = self._update_causal_mask(
                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
@@ -673,6 +674,7 @@ class KDeepseekV2Model(BaseInjectedModule):
        t_f = 0
        for i, decoder_layer in enumerate(self.layers):
            # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
            if self.transfer_map is not None and i in self.transfer_map:
                prev_stream = torch.cuda.current_stream()
                cur_device = self.transfer_map[i]
......
@@ -6,7 +6,7 @@
import triton
import triton.language as tl
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
    # [TODO] work around shmem limit on MI3xx
    # TODO: support hip
-   #if is_hip_ and Lk >= 576:
-   #    BLOCK = 16
    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
        BLOCK = 16
    if Lk == 576:
        BLOCK_DMODEL = 512
......
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
"""
Memory-efficient attention for prefill.
It supports page size = 1.
"""
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl
is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
Out,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_obs,
stride_oh,
kv_group_num: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
Lk: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
mask_d = offs_d < Lk
q = tl.load(
Q + off_q,
mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
other=0.0,
)
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
end_n = (
cur_batch_seq_len
if not IS_CAUSAL
else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
)
for start_n in range(0, block_mask * end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
other=0.0,
)
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if IS_CAUSAL:
qk += tl.where(
(start_n + offs_n[None, :] < cur_batch_seq_len)
& (offs_m[:, None] >= (start_n + offs_n[None, :])),
0,
float("-inf"),
)
else:
qk += tl.where(
(start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
)
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
)
out_ptrs = Out + off_o
tl.store(
out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
)
def context_attention_fwd(
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
"""
q, k, v: [b * s, head, head_dim]
b_start_loc: [b]
b_seq_len: [b]
out: [b * s, head, head_dim]
"""
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
BLOCK = 128
else:
BLOCK = 64
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
kv_group_num=kv_group_num,
BLOCK_M=BLOCK,
BLOCK_DMODEL=triton.next_power_of_2(Lk),
BLOCK_N=BLOCK,
IS_CAUSAL=is_causal,
num_warps=num_warps,
num_stages=1,
Lk=Lk,
)
\ No newline at end of file
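As a usage note, here is a minimal sketch of how `context_attention_fwd` is driven for the bsz = 1 prefill path wired up in attention.py above. The sizes are made-up and, for simplicity, q/k/v share one head_dim; tensors follow the docstring's [b * s, head, head_dim] packing:

import torch
from ktransformers.operators.triton_attention_prefill import context_attention_fwd

# illustrative sizes only
bsz, q_len, n_heads, head_dim = 1, 16, 8, 128

q = torch.randn(bsz * q_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz * q_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz * q_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
o = torch.zeros(bsz * q_len, n_heads, head_dim, device="cuda", dtype=torch.float16)

b_start_loc = torch.zeros(bsz, dtype=torch.int64, device="cuda")         # start offset of each sequence
b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device="cuda")  # length of each sequence

context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len=q_len, is_causal=True)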
@@ -22,7 +22,7 @@
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
-     generate_device: "cuda"
      generate_device: "cpu"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
......
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearCPUInfer"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cpu"
prefill_device: "cuda"
generate_op: "KLinearQ8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
\ No newline at end of file
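For readers double-checking the match rules above, a small sketch of how the negative-lookahead pattern behaves (the module names are illustrative DeepSeek-style layer names, not taken from a real checkpoint):

import re

# The rule replaces every Linear under model.layers except self_attn.kv_b_proj.
pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

print(bool(pattern.match("model.layers.3.mlp.gate_proj")))        # True  -> replaced (KLinearQ8 on CPU)
print(bool(pattern.match("model.layers.3.self_attn.q_a_proj")))   # True  -> replaced
print(bool(pattern.match("model.layers.3.self_attn.kv_b_proj")))  # False -> left untouched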
import torch

# Define a floating-point model containing a linear layer
class LinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

# Create the floating-point model instance
in_features = 64
out_features = 128
model_fp32 = LinearModel(in_features, out_features)

# Create a dynamically quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,          # the original floating-point model
    {torch.nn.Linear},   # the set of layer types to quantize
    dtype=torch.qint8    # the target dtype for quantization
)

# Run the model
batch_size = 32
input_fp32 = torch.randn(1, batch_size, in_features)  # random input data
output_int8 = model_int8(input_fp32)                  # run it through the quantized model

# Print the shapes to verify
print(f"Input shape: {input_fp32.shape}")
print(f"Output shape: {output_int8.shape}")

# Compare the outputs of the original and quantized models
with torch.no_grad():
    output_fp32 = model_fp32(input_fp32)
    print(f"First values of the FP32 output: {output_fp32[0, :5]}")
    print(f"First values of the INT8 output: {output_int8[0, :5]}")

    # Compute the mean absolute error
    error = torch.abs(output_fp32 - output_int8).mean().item()
    print(f"Mean absolute error: {error}")

# Print the module types before and after quantization
print(f"Layer type before quantization: {type(model_fp32.linear)}")
print(f"Layer type after quantization: {type(model_int8.linear)}")
\ No newline at end of file
from __future__ import annotations
from enum import IntEnum, auto
from typing import Optional, Union, List
import torch
class GPUVendor(IntEnum):
NVIDIA = auto()
AMD = auto()
MooreThreads = auto()
MetaX = auto()
MUSA = auto()
Unknown = auto()
class DeviceManager:
"""
Device manager that provides a unified interface for handling different GPU vendors
"""
def __init__(self):
self.gpu_vendor = self._detect_gpu_vendor()
self.available_devices = self._get_available_devices()
def _detect_gpu_vendor(self) -> GPUVendor:
"""Detect GPU vendor type"""
if not torch.cuda.is_available():
# Check MUSA availability (assuming a musa module exists)
try:
import musa
if musa.is_available():
return GPUVendor.MUSA
except (ImportError, AttributeError):
pass
return GPUVendor.Unknown
device_name = torch.cuda.get_device_name(0).lower()
if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
return GPUVendor.NVIDIA
elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
return GPUVendor.AMD
elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
return GPUVendor.MooreThreads
elif any(name in device_name for name in ["metax", "meta"]):
return GPUVendor.MetaX
elif "musa" in device_name:
return GPUVendor.MUSA
# Backend check
try:
if hasattr(torch.version, 'hip') and torch.version.hip is not None:
return GPUVendor.AMD
elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
return GPUVendor.NVIDIA
except:
pass
return GPUVendor.Unknown
def _get_available_devices(self) -> List[int]:
"""Get list of available device indices"""
devices = []
if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
devices = list(range(torch.cuda.device_count()))
elif self.gpu_vendor == GPUVendor.MUSA:
try:
import musa
devices = list(range(musa.device_count()))
except (ImportError, AttributeError):
pass
return devices
def get_device_str(self, device_id: Union[int, str]) -> str:
"""
Get device string for the given device ID
Args:
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
Device string representation (e.g., "cuda:0", "musa:1", "cpu")
"""
if device_id == -1 or device_id == "cpu":
return "cpu"
if isinstance(device_id, int):
if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
if device_id < torch.cuda.device_count():
return f"cuda:{device_id}"
elif self.gpu_vendor == GPUVendor.MUSA:
try:
import musa
if device_id < musa.device_count():
return f"musa:{device_id}"
except (ImportError, AttributeError):
pass
return "cpu"
def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
"""
Convert device ID to torch.device object
Args:
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
torch.device object
"""
device_str = self.get_device_str(device_id)
# Handle MUSA device
if device_str.startswith("musa:"):
try:
import musa
index = int(device_str.split(":")[-1])
return musa.device(index)
except (ImportError, ValueError, AttributeError):
return torch.device("cpu")
# Standard PyTorch device
return torch.device(device_str)
def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
"""
Move tensor to specified device
Args:
tensor: PyTorch tensor to move
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
Tensor moved to the specified device
"""
device = self.to_torch_device(device_id)
return tensor.to(device)
def is_available(self, index: int = 0) -> bool:
"""
Check if device at specified index is available
Args:
index: Device index to check
Returns:
True if the device is available, False otherwise
"""
if index < 0:
return True # CPU is always available
return index in self.available_devices
def get_all_devices(self) -> List[int]:
"""
Get all available device indices
Returns:
List of available device indices (0, 1, 2, etc.)
"""
return self.available_devices
# Create global device manager instance
device_manager = DeviceManager()
# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
"""
Get torch.device object for the specified device ID
Args:
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
torch.device object
"""
return device_manager.to_torch_device(device_id)
def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
"""
Move tensor to specified device
Args:
tensor: PyTorch tensor to move
device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string
Returns:
Tensor moved to the specified device
"""
return device_manager.move_tensor_to_device(tensor, device_id)
# Get devices
cpu_device = get_device(-1) # CPU using index -1
cpu_device2 = get_device("cpu") # CPU using string "cpu"
gpu0 = get_device(0) # First GPU
# Move tensors
x = torch.randn(3, 3)
x_gpu = to_device(x, 0) # Move to first GPU
x_cpu1 = to_device(x, -1) # Move to CPU using index -1
x_cpu2 = to_device(x, "cpu") # Move to CPU using string "cpu"
\ No newline at end of file
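As a usage note, the vendor check added here is what gates the FlashInfer/Triton fast paths elsewhere in this release (local_chat.py, attention.py, models.py). A minimal sketch of that gate; the helper below is illustrative, not part of the module:

import torch
from ktransformers.util.vendors import device_manager, GPUVendor

def can_use_flash_path() -> bool:
    # fast attention paths are only taken on Ampere-or-newer NVIDIA GPUs
    if device_manager.gpu_vendor != GPUVendor.NVIDIA or not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

print("fast attention path:", can_use_flash_path())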
@@ -29,7 +29,7 @@ import torch.version
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension
from cpufeature.extension import CPUFeature
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
try:
    from torch_musa.utils.simple_porting import SimplePorting
    from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
@@ -64,6 +64,70 @@ class VersionInfo:
        musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
        return musa_version
def get_rocm_bare_metal_version(self, rocm_dir):
"""
Get the ROCm version from the ROCm installation directory.
Args:
rocm_dir: Path to the ROCm installation directory
Returns:
A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
"""
try:
            # Try using rocminfo to get the version info
raw_output = subprocess.check_output(
[rocm_dir + "/bin/rocminfo", "--version"],
universal_newlines=True,
stderr=subprocess.STDOUT)
# Extract version number from output
match = re.search(r'(\d+\.\d+)', raw_output)
if match:
version_str = match.group(1)
version = parse(version_str)
rocm_version = f"{version.major}{version.minor}"
return rocm_version
except (subprocess.CalledProcessError, FileNotFoundError):
# If rocminfo --version fails, try alternative methods
pass
try:
# Try reading version from release file
with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
version_str = f.read().strip()
version = parse(version_str)
rocm_version = f"{version.major}{version.minor}"
return rocm_version
except (FileNotFoundError, IOError):
pass
# If all else fails, try to extract from directory name
dir_name = os.path.basename(os.path.normpath(rocm_dir))
match = re.search(r'rocm-(\d+\.\d+)', dir_name)
if match:
version_str = match.group(1)
version = parse(version_str)
rocm_version = f"{version.major}{version.minor}"
return rocm_version
# Fallback to extracting from hipcc version
try:
raw_output = subprocess.check_output(
[rocm_dir + "/bin/hipcc", "--version"],
universal_newlines=True,
stderr=subprocess.STDOUT)
match = re.search(r'HIP version: (\d+\.\d+)', raw_output)
if match:
version_str = match.group(1)
version = parse(version_str)
rocm_version = f"{version.major}{version.minor}"
return rocm_version
except (subprocess.CalledProcessError, FileNotFoundError):
pass
# If we still can't determine the version, raise an error
raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
    def get_cuda_bare_metal_version(self, cuda_dir):
        raw_output = subprocess.check_output(
            [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -148,11 +212,13 @@ class VersionInfo:
        cpu_instruct = self.get_cpu_instruct()
        backend_version = ""
        if CUDA_HOME is not None:
-           backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
            backend_version = f""
        elif MUSA_HOME is not None:
            backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
        elif ROCM_HOME is not None:
            backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
        else:
-           raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
            raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME and ROCM_HOME are all unset.")
        package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
        if full_version:
            return package_version
@@ -247,8 +313,12 @@ class CMakeBuild(BuildExtension):
            cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
        elif MUSA_HOME is not None:
            cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
        elif ROCM_HOME is not None:
            cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
        else:
            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
        # log cmake_args
        print("CMake args:", cmake_args)
        build_args = []
        if "CMAKE_ARGS" in os.environ:
@@ -328,7 +398,7 @@ class CMakeBuild(BuildExtension):
            ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
        )
-if CUDA_HOME is not None:
if CUDA_HOME is not None or ROCM_HOME is not None:
    ops_module = CUDAExtension('KTransformersOps', [
        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
        'ktransformers/ktransformers_ext/cuda/binding.cpp',
@@ -338,7 +408,7 @@ if CUDA_HOME is not None:
            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
            'nvcc': [
                '-O3',
-               '--use_fast_math',
                # '--use_fast_math',
                '-Xcompiler', '-fPIC',
                '-DKTRANSFORMERS_USE_CUDA',
            ]
@@ -371,6 +441,7 @@ else:
    raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
setup(
    name=VersionInfo.PACKAGE_NAME,
    version=VersionInfo().get_package_version(),
    cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
    ext_modules=[
......