Unverified Commit ca1dc1e7 authored by Atream, committed by GitHub

Merge branch 'main' into main

parents d3b45d57 505f4e2c
@@ -10,6 +10,8 @@
 #include "kvcache.h"
+#include <chrono>

 void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
                                    int block_idx, Backend *backend) {
     // Timer start
...
@@ -10,6 +10,8 @@
 #include "kvcache.h"
+#include <chrono>

 std::string ggml_type_to_string(ggml_type type) {
     switch (type) {
     case GGML_TYPE_F32:
...
# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
from typing import Tuple
import torch
import triton
import triton.language as tl
from triton import Config
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
Args:
x_ptr (triton.Pointer): Pointer to the input tensor.
y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
Returns:
None
"""
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), 'Input tensor must be contiguous'
assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The scale tensor of shape (M, N).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
with torch.cuda.device(x.device):
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
Args:
a_ptr (tl.tensor): Pointer to the first input matrix A.
b_ptr (tl.tensor): Pointer to the second input matrix B.
c_ptr (tl.tensor): Pointer to the output matrix C.
a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
M (int): Number of rows in matrix A and C.
N (tl.constexpr): Number of columns in matrix B and C.
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
"""
Perform a matrix multiplication using FP8 precision.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
\ No newline at end of file
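A minimal usage sketch for the kernels above (not part of the commit): the shapes, the CUDA device, and the bfloat16 default dtype are assumptions, and the weight is block-quantized on the fly here purely as a stand-in for a checkpoint that already ships FP8 weights with a 128x128 scale grid.

import torch
from ktransformers.ktransformers_ext.triton.fp8gemm import act_quant, weight_dequant, fp8_gemm

torch.set_default_dtype(torch.bfloat16)
M, K, N = 4, 7168, 2048                      # illustrative shapes, all multiples of 128

x = torch.randn(M, K, device="cuda")         # bf16 activations
w = torch.randn(N, K, device="cuda")         # bf16 weight in (out_features, in_features) layout

# Per-token (1x128) activation quantization: fp8 values plus one fp32 scale per 128-wide block.
x_q, x_s = act_quant(x, block_size=128)      # x_q: (4, 7168) float8_e4m3fn, x_s: (4, 56) float32

# Build a 128x128 block-quantized weight on the fly (stand-in for a pre-quantized checkpoint).
blocks = w.view(N // 128, 128, K // 128, 128)
w_s = blocks.abs().amax(dim=(1, 3)).float() / 448.0                      # (16, 56) per-block scales
w_q = (blocks / w_s[:, None, :, None]).view(N, K).to(torch.float8_e4m3fn)

y = fp8_gemm(x_q, x_s, w_q, w_s)             # (4, 2048), returned in the default dtype
w_ref = weight_dequant(w_q, w_s)             # (2048, 7168) bf16 reconstruction of the weight
y_ref = x @ w_ref.T                          # dense reference for a rough sanity check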
@@ -28,8 +28,9 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -53,7 +54,7 @@ default_optimize_rules = {
 def local_chat(
     model_path: str | None = None,
-    optimize_rule_path: str = None,
+    optimize_config_path: str = None,
     gguf_path: str | None = None,
     max_new_tokens: int = 300,
     cpu_infer: int = Config().cpu_infer,
@@ -61,9 +62,9 @@ def local_chat(
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
+    chunk_prefill_size: int = 8192
 ):
     torch.set_grad_enabled(False)
     Config().cpu_infer = cpu_infer
@@ -94,12 +95,12 @@ def local_chat(
             config, trust_remote_code=True, attn_implementation="flash_attention_2"
         )
-    if optimize_rule_path is None:
+    if optimize_config_path is None:
         if config.architectures[0] in default_optimize_rules:
             print("using default_optimize_rule for", config.architectures[0])
-            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+            optimize_config_path = default_optimize_rules[config.architectures[0]]
         else:
-            optimize_rule_path = input(
+            optimize_config_path = input(
                 "please input the path of your rule file(yaml file containing optimize rules):"
             )
@@ -107,18 +108,18 @@ def local_chat(
         gguf_path = input(
             "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
         )
-    optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
+    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
+    except Exception as e:
+        print(f"generation config can't auto create, make default. Message: {e}")
         gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.7,
-            top_p=0.9,
+            temperature=0.6,
+            top_p=0.95,
             do_sample=True
         )
         model.generation_config = gen_config
     # model.generation_config = GenerationConfig.from_pretrained(model_path)
     if model.generation_config.pad_token_id is None:
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
@@ -167,13 +168,17 @@ def local_chat(
     if mode == 'long_context':
         assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
         "please change max_seq_len in ~/.ktransformers/config.yaml"
-    torch.set_default_dtype(
-        torch.bfloat16
-    ) # TODO: Remove this, replace dtype using config
-    generated = prefill_and_generate(
-        model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
-    )
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
+            use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
+        )
+    else:
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
+        )

 if __name__ == "__main__":
     fire.Fire(local_chat)
\ No newline at end of file
@@ -138,8 +138,6 @@ class StaticCache(transformers.StaticCache):
         page_idx = cache_position // self.page_size
         page_offset = cache_position % self.page_size
         # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
-        #print("page_idx", page_idx)
-        #print("page_offset", page_offset)
         k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
         k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
         return k_out, self.page_table_list[layer_idx]
@@ -172,8 +170,21 @@ class StaticCache(transformers.StaticCache):
         for layer_idx in range(len(self.key_cache)):
             # In-place ops prevent breaking the static address
             self.key_cache[layer_idx].zero_()
-            self.value_cache[layer_idx].zero_()
+            if self.value_cache[layer_idx] is not None:
+                self.value_cache[layer_idx].zero_()
+            self.past_tokens[layer_idx] = 0
+
+    def remove_suffix(self, start_pos):
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            if self.is_MLA:
+                k_cache = self.key_cache[layer_idx]
+                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
+            else:
+                self.key_cache[layer_idx][..., start_pos:, :].zero_()
+                self.value_cache[layer_idx][..., start_pos:, :].zero_()
+            self.past_tokens[layer_idx] = start_pos

     def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
         """Returns the maximum shape of the cache."""
         return self.max_cache_len
\ No newline at end of file
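A small illustration of what the new remove_suffix helper above is for (a sketch only; `cache` and `prefix_len` are hypothetical and stand for an already-populated ktransformers StaticCache and a prompt prefix length):

# Roll the static KV cache back to a prompt prefix before re-decoding,
# instead of rebuilding the whole cache from scratch.
prefix_len = 1024
cache.remove_suffix(prefix_len)   # zeroes cached entries from position 1024 onward (MLA or regular layout)
# per-layer bookkeeping is rewound as well (past_tokens is assumed indexable per layer, as in the diff)
assert cache.past_tokens[0] == prefix_len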
@@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
         )
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits[:,-1,:].unsqueeze(0).float()
+        logits = self.lm_head(hidden_states[:,-1:,:]).float()
         loss = None
         if labels is not None:
...
@@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         )
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states.to(self.lm_head.weight.device))
+        logits = self.lm_head(hidden_states[:,-1:,:])
         logits = logits.float()
         loss = None
...
@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim, orig_module.max_position_embeddings, orig_module.base
@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
 #         **kwargs,
 #     ):
 #         BaseInjectedModule.__init__(
-#             self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+#             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
 #         )
 #         self.generate_device = generate_device
 #         self.prefill_device = prefill_device
@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
...
This diff is collapsed.
@@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
         gguf_loader : GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs):
         nn.Module.__init__(self)
         nn.Module.__setattr__(self, "orig_module", orig_module)
         object.__setattr__(self, "key", key)
         object.__setattr__(self, "gguf_loader", gguf_loader)
         object.__setattr__(self, "config", config)
-        object.__setattr__(self, "device", device)
+        object.__setattr__(self, "prefill_device", prefill_device)
+        object.__setattr__(self, "generate_device", generate_device)
+        object.__setattr__(self, "device", generate_device)

     def __getattr__(self, name: str) -> Any:
         # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
...
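For context, a hypothetical injected operator showing how the widened constructor above is meant to be called; the subclass name and body are illustrative, not from the commit:

class MyInjectedLinear(BaseInjectedModule):
    def __init__(self, key, gguf_loader, config, orig_module,
                 prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs):
        # Both devices are now forwarded; `self.device` keeps pointing at the generate device.
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.prefill_device = prefill_device
        self.generate_device = generate_device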
@@ -18,6 +18,7 @@ import torch.nn.functional as F
 import torch
 import sys, os
 from ktransformers.operators.base_operator import BaseInjectedModule
+from tqdm import tqdm

 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
@@ -118,6 +119,7 @@ class KExpertsCPU(KExpertsBase):
     output_cpu:Tensor = None
     output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
     #stream_map:dict = {} # Manage cuda stream on different gpu
+    #gguf_loader:GGUFLoader = None
     CPU_INFER = CPUInfer(Config().cpu_infer)
     def __init__(
         self,
@@ -131,6 +133,9 @@ class KExpertsCPU(KExpertsBase):
         **kwargs
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        #if KExpertsCPU.gguf_loader is None:
+        #    KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
+        self.gguf_loader = gguf_loader
         assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
         self.n_routed_experts = n_routed_experts
         self.out_device = out_device
@@ -154,7 +159,7 @@ class KExpertsCPU(KExpertsBase):
         down_ptr = ctypes.addressof(
             ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
         )
-        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
+        #print(self.gate_type, self.up_type, self.down_type)
         n_routed_experts = self.n_routed_experts
         # n_routed_experts = len(self.orig_module)
         moe_config = MOEConfig(
@@ -225,6 +230,7 @@ class KExpertsCPU(KExpertsBase):
         return

     def load_weights(self, override_key: str | None = None, device: str = "cpu"):
+        # TODO: support Bias
         res = {}
         if override_key is not None:
             keys = override_key
@@ -239,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
         down_type = None

         for key in keys:
-            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+            if self.gguf_loader.safetensor_loader is not None:
+                # using a temp ugly way to temprary load the tensor
+                gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
+                up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
+                down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
+                gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
+                up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
+                down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
+            elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                 gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                 up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                 down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
@@ -288,6 +303,8 @@ class KExpertsMarlin(KExpertsBase):
         self.act_fn = ACT2FN[config.hidden_act]
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
         self.device = device
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
         # create empty marlin experts according to the number of experts per token
         # up
         self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
@@ -299,17 +316,34 @@ class KExpertsMarlin(KExpertsBase):
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
-        if w is None: w = self.load_weights()[self.key]
-        if isinstance(w, dict):
-            self.gate = w["gate"]
-            self.up = (w["up"])
-            self.down = (w["down"])
-            for i in range(self.expert_num):
-                self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
-                self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
-                self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
-                self.loaded_experts_idx.append(i)
+        if w is None:
+            w = self.load_weights()
+            load_by_experts = True
+
+        if load_by_experts:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)
+                    self.up_projs[i].load(nn.Parameter(up_weights), device=device)
+                    self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
+                    self.down_projs[i].load(nn.Parameter(down_weights), device=device)
+                    self.loaded_experts_idx.append(i)
+        else:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in range(self.expert_num):
+                    self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
+                    self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
+                    self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
+                    self.loaded_experts_idx.append(i)
         return

     def unload(self):
@@ -329,20 +363,13 @@ class KExpertsMarlin(KExpertsBase):
         gate = None
         up = None
         down = None
-        gate_type = None
-        up_type = None
-        down_type = None

         for key in keys:
             if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
-                gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
-                up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
-                down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
-                gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
-                up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
-                down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
-                # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
-                res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+                res = {"gate": gate, "up": up, "down": down}
         return res

     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
@@ -381,6 +408,7 @@ class KExpertsMarlin(KExpertsBase):
         return final_hidden_states.to(dtype=org_dtype, device=org_device)

+# untested, CUDA OOM
 class KExpertsTorch(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]
@@ -402,19 +430,39 @@ class KExpertsTorch(KExpertsBase):
         # self.loaded_experts_idx = []
         self.act_fn = ACT2FN[config.hidden_act]
         self.device = device
-        self.gate = None
-        self.up = None
-        self.donw = None
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
+        self.gate = [None for _ in range(self.expert_num)]
+        self.up = [None for _ in range(self.expert_num)]
+        self.down = [None for _ in range(self.expert_num)]
         self.dtype = torch.get_default_dtype()

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
-        if w is None: w = self.load_weights(device=device)[self.key]
-        if isinstance(w, dict):
-            self.gate = w["gate"].to(device=device, dtype=self.dtype)
-            self.up = w["up"].to(device=device, dtype=self.dtype)
-            self.down = w["down"].to(device=device, dtype=self.dtype)
+        if w is None:
+            w = self.load_weights()
+            load_by_experts = True
+
+        if load_by_experts:
+            if isinstance(w, dict):
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)
+                    self.up[i] = up_weights
+                    self.gate[i] = gate_weights
+                    self.down[i] = down_weights
+        else:
+            if isinstance(w, dict):
+                for i in range(self.expert_num):
+                    self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
+                    self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
+                    self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
+        self.up = torch.stack(self.up, dim=0)
+        self.gate = torch.stack(self.gate, dim=0)
+        self.down = torch.stack(self.down, dim=0)
+        return

     def unload(self):
         if self.gate is not None:
@@ -422,6 +470,25 @@ class KExpertsTorch(KExpertsBase):
             self.up = None
             self.down = None

+    def load_weights(self, override_key: str | None = None):
+        res = {}
+        if override_key is not None:
+            keys = override_key
+        else:
+            keys = [self.key]
+
+        gate = None
+        up = None
+        down = None
+
+        for key in keys:
+            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+                res = {"gate": gate, "up": up, "down": down}
+        return res
+
     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:

         org_device = hidden_states_cpu.device
@@ -478,7 +545,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
                  generate_device: str = "cpu",
                  generate_op: str | None = "KExpertsCPU",
                  **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         if generate_op is not None:
             self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
@@ -582,7 +649,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)
@@ -601,8 +668,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         return y, router_logits

     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
@@ -672,7 +738,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
             y_ = self.shared_experts(identity).squeeze(0)
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
@@ -692,8 +758,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
         return y

     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
@@ -773,7 +838,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y_ = self.shared_experts(identity).squeeze(0)
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
@@ -793,8 +858,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         return y

     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
@@ -881,7 +945,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)
@@ -900,8 +964,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
         return y, router_logits

     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
...
'''
Description : flashinfer MLA wrapper
Author : Boxin Zhang
Version : 0.2.2
'''
import torch
flashinfer_enabled = False
try:
import flashinfer
flashinfer_enabled = True
print("found flashinfer")
except ImportError:
print("flashinfer not found, use triton for linux")
import math
def attention_ref(
batch_size,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
causal: bool,
sm_scale: float,
) -> torch.Tensor:
qo_len = q.shape[0] // batch_size
kv_len = k.shape[0] // batch_size
num_qo_heads = q.shape[1]
head_dim_qk = q.shape[2]
head_dim_vo = v.shape[2]
logits = (
torch.einsum(
"bmhd,bnhd->bhmn",
q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
)
* sm_scale
)
#print("attn weights", logits)
if causal:
mask = (
torch.arange(kv_len - qo_len, kv_len).unsqueeze(1)
>= torch.arange(0, kv_len).unsqueeze(0)
).to(q.device)
else:
mask = torch.ones(qo_len, kv_len).to(q.device)
logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
p = torch.softmax(logits, dim=-1)
o_ref = (
torch.einsum(
"bhmn,bnhd->bmhd",
p,
v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
)
.contiguous()
.view(batch_size * qo_len, num_qo_heads, head_dim_vo)
.to(q)
)
return o_ref, lse_ref * math.log2(math.e)
class MLAWrapper():
def __init__(self,
max_batch_size,
max_pages,
use_cuda_graph = True,
device = "cuda",
):
self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device)
self.max_batch_size = max_batch_size
self.max_pages = max_pages
if use_cuda_graph:
if self.max_batch_size == 1:
self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device)
self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device)
self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
else:
self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
else:
self.qo_indptr_buf = None
self.kv_indptr_buf = None
self.kv_indices_buf = None
self.kv_len_arr_buf = None
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.float_workspace_buffer,
use_cuda_graph=False,
qo_indptr=self.qo_indptr_buf,
kv_indptr=self.kv_indptr_buf,
kv_indices=self.kv_indices_buf,
kv_len_arr=self.kv_len_arr_buf,
)
self.need_plan = True
def plan(self,
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
num_heads,
head_dim_ckv,
head_dim_kpe,
page_size,
sm_scale,
q_data_type,
kv_data_type,
):
if qo_indptr is None:
assert self.max_batch_size == 1
qo_indptr = self.qo_indptr_buf
if kv_indptr is None:
assert self.max_batch_size == 1
kv_indptr = self.kv_indptr_buf
if kv_indices is None:
assert self.max_batch_size == 1
kv_indices = self.kv_indices_buf
self.wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
num_heads,
head_dim_ckv,
head_dim_kpe,
page_size,
True, # causal
sm_scale,
q_data_type,
kv_data_type,
)
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
#print("run")
#print(self.wrapper._qo_indptr_buf)
#print(self.wrapper._kv_indptr_buf)
#print(self.wrapper._kv_indices_buf)
#print(self.wrapper._kv_len_arr_buf)
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
class MLAWrapperSingleton():
wrappers:dict = {}
@classmethod
def get_instance(cls, device, *args, **kwargs)->MLAWrapper:
if device not in cls.wrappers:
cls.make_instance(device, *args, **kwargs)
return cls.wrappers[device]
@classmethod
def make_instance(cls, device, *args, **kwargs):
cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device)
@classmethod
def plan_all(cls, qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
num_heads,
head_dim_ckv,
head_dim_kpe,
page_size,
sm_scale,
q_data_type,
kv_data_type,):
for device, wrapper in cls.wrappers.items():
kv_len_arr_cur_device = kv_len_arr.to(device)
wrapper.plan(qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr_cur_device,
num_heads,
head_dim_ckv,
head_dim_kpe,
page_size,
sm_scale,
q_data_type,
kv_data_type,)
wrapper.need_plan = False
@classmethod
def need_plan_all(cls):
for device, wrapper in cls.wrappers.items():
wrapper.need_plan = True
@classmethod
def reset_buffer(cls):
for device, wrapper in cls.wrappers.items():
wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here.
@classmethod
def update_buffer(cls, max_pages):
for device, wrapper in cls.wrappers.items():
wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
max_batch_size = 1
max_pages = 128
page_size = 64
num_heads = 128
kv_len = 4023
q_len = 1
q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
wrapper = MLAWrapperSingleton.get_instance(
"cuda",
max_batch_size,
max_pages,
)
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
None,
None,
kv_len_arr,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
print(attn_output.shape)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
kv_len = 6789
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
None,
None,
kv_len_arr,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
graph.replay()
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
print(k[:kv_len].shape)
print(v[:kv_len].shape)
attn_ref, lse_ref = attention_ref(
max_batch_size,
torch.cat([q_nope, q_pe], dim=-1),
k[:kv_len],
v[:kv_len],
True,
192 ** (-0.5)
)
print(attn_ref.shape)
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
print("test past")
\ No newline at end of file
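A condensed decode-time sketch of the singleton wrapper above; it mirrors what the file's own __main__ test exercises (single batch, bfloat16 defaults, illustrative page and head sizes), nothing beyond that:

import torch
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, flashinfer_enabled

if flashinfer_enabled:
    torch.set_default_dtype(torch.bfloat16)
    page_size, max_pages = 64, 128
    wrapper = MLAWrapperSingleton.get_instance("cuda", 1, max_pages)   # max_batch_size = 1

    # Plan once for the current KV length, then run per decode step.
    kv_len_arr = torch.tensor([4023], dtype=torch.int32, device="cuda")
    qo_indptr = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
    wrapper.plan(qo_indptr, None, None, kv_len_arr,
                 128, 512, 64, page_size, 192 ** (-0.5),
                 torch.bfloat16, torch.bfloat16)

    q_nope = torch.randn((1, 128, 512), device="cuda")
    q_pe = torch.randn((1, 128, 64), device="cuda")
    ckv = torch.randn((max_pages, page_size, 512), device="cuda")
    k_pe = torch.randn((max_pages, page_size, 64), device="cuda")
    out = wrapper.run(q_nope, q_pe, ckv, k_pe)   # (1, 128, 512) attention output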
@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
         for key in keys:
             key = ".".join(key.split(".")[:-1])
-            if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
+            if self.gguf_loader.safetensor_loader is not None:
+                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+                weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
+                e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
+                weight_type = weight.dtype
+                e_score_correction_bias_type = e_score_correction_bias.dtype
+                res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
+            elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                 targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                 tensors = self.load_multi(key, targets, device=device)
                 weight = tensors[".ffn_gate_inp.weight"]
@@ -93,11 +100,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module = None,
-        generate_device: str = "cuda",
         prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs,
     ):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
             self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
         else:
             raise ValueError("Invalid weight type")
-        self.orig_module.weight = self.orig_module.weight.to(device)
-        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
+        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))

     def unload(self):
         if self.weight is not None:
...
@@ -21,10 +21,12 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
     MarlinWorkspace,
     marlin_quantize,
     GPTQ_MARLIN_MIN_THREAD_N,
+    GPTQ_MARLIN_MIN_THREAD_K,
     GPTQ_MARLIN_MAX_PARALLEL,
 )
 from ktransformers.operators.base_operator import BaseInjectedModule
 from transformers.configuration_utils import PretrainedConfig
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
 from abc import ABC, abstractmethod
 import sys, os
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
@@ -64,6 +66,8 @@ class KLinearBase(ABC):
             self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
             self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]

+        self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill.
+
     @abstractmethod
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         pass
@@ -75,7 +79,13 @@ class KLinearBase(ABC):
             keys = [self.key]

         for key in keys:
-            if key + ".weight" in self.gguf_loader.tensor_file_map:
+            if self.gguf_loader.safetensor_loader is not None:
+                # using safetensor_loader
+                tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
+                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
+                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+            elif key + ".weight" in self.gguf_loader.tensor_file_map:
                 if key + ".bias" in self.gguf_loader.tensor_file_map:
                     tensors = self.load_multi(key, ["weight", "bias"], device=device)
                     tensor = tensors["weight"]
@@ -119,7 +129,7 @@ class KLinearTorch(KLinearBase):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
-        self.w = None
+        self.weight = None
         self.has_bias = False

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -127,44 +137,100 @@ class KLinearTorch(KLinearBase):
         out_device = x.device
         # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
         x = x.to(device=self.device, dtype=self.dtype)
-        x = x @ self.w
+        x = x @ self.weight
         if self.has_bias:
             x = x + self.bias
         x = x.to(dtype=dtype, device=out_device)
         return x

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if self.loaded: return
         if device is None: device = self.device
         if w is None: w = self.load_weight(device=device)
         # else: self.out_features = w.shape[0], self.in_features = w.shape[1]

         if isinstance(w, nn.Parameter):
             try:
-                self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w.to(dtype=self.dtype).T
+                self.weight = w.to(dtype=self.dtype).T
             self.has_bias = False
         elif isinstance(w, tuple):
             try:
-                self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w[0].to(dtype=self.dtype).T
+                self.weight = w[0].to(dtype=self.dtype).T
             self.bias = w[1].to(dtype=self.dtype)
             self.has_bias = True
         else:
             raise ValueError("Invalid weight type")
         # self.linear = self.linear.to(device)
-        self.w = self.w.to(device)
+        self.weight = self.weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)
+        self.loaded = True

     def unload(self):
-        if self.w is not None:
-            self.w = None
+        if self.weight is not None:
+            self.weight = None
         if self.has_bias:
             self.bias = None
class KLinearFP8(KLinearBase):
# this kernel requires special handling for weight
# Please load the weight file downloaded from KVCache.AI
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
sort_indices: torch.Tensor
has_bias: bool
weight: torch.Tensor
scale_w: torch.Tensor
bias: torch.Tensor
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
block_size: int = 128,
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.dtype = torch.get_default_dtype()
self.block_size = block_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
orig_dtype = x.dtype
x_quantized, scale_x = act_quant(x, self.block_size)
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
return y.to(dtype=orig_dtype)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None:
w = self.load_weight(device=device)
### TODO fit weight_inv format
if isinstance(w, tuple):
self.weight = w[0].to(device)
self.weight_scale_inv = w[1].to(device)
self.has_bias = False
else:
raise ValueError("Invalid weight type")
self.weight = self.weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
def unload(self):
if self.weight is not None:
self.weight = None
if self.has_bias:
self.bias = None
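The FP8 path above quantizes activations on the fly with act_quant and calls fp8_gemm against a blockwise-scaled weight. As a hedged reference only (assuming 128x128 weight blocks and that weight_scale_inv is the per-block multiplier applied during dequantization, as in DeepSeek-V3 FP8 checkpoints), a pure-PyTorch dequantization that can be used to sanity-check the kernel output could look like this:

# hedged reference sketch, not the kernel itself; the block layout is an assumption
import torch

def dequant_blockwise(weight_fp8: torch.Tensor, scale_inv: torch.Tensor, block: int = 128) -> torch.Tensor:
    # weight_fp8: (out_features, in_features) in float8_e4m3fn
    # scale_inv:  (ceil(out_features/block), ceil(in_features/block)) in float32
    w = weight_fp8.to(torch.float32)
    s = scale_inv.repeat_interleave(block, dim=0)[: w.shape[0]]
    s = s.repeat_interleave(block, dim=1)[:, : w.shape[1]]
    return w * s

# y_ref = x.float() @ dequant_blockwise(weight, weight_scale_inv).T
# should approximate fp8_gemm(x_quantized, scale_x, weight, weight_scale_inv),
# up to the activation-quantization error introduced by act_quant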
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
...@@ -190,20 +256,36 @@ class KLinearMarlin(KLinearBase):
self.group_size = group_size
self.act_order = act_order
self.is_k_full = is_k_full
self.padding = False
self.orin_in_features = self.in_features
self.orin_out_features = self.out_features
if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
#print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
self.padding = True
self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
#print(f"After padding: in_features={in_features}, out_features={out_features}")
self.k = self.in_features
self.n = self.out_features
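The padding above is plain round-up-to-multiple arithmetic; a quick illustration (the multiple of 64 is only an example, not necessarily the actual GPTQ_MARLIN_MIN_THREAD_K/N values):

# illustrative check of the (x + m - 1) // m * m round-up used above
def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

assert round_up(4096, 64) == 4096   # already aligned: unchanged
assert round_up(4100, 64) == 4160   # padded up to the next multiple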
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if self.loaded: return
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
#if self.in_features * self.out_features:
if w is None:
w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
# pad weight
-weight = w.view(self.out_features, self.in_features).T
weight = w.view(self.orin_out_features, self.orin_in_features).T
self.has_bias = False
elif isinstance(w, tuple):
w = list(w)
-weight = w[0].view(self.out_features, self.in_features).T
weight = w[0].view(self.orin_out_features, self.orin_in_features).T
self.bias = w[1].view(self.orin_out_features)
self.bias = w[1]
self.has_bias = True
else:
...@@ -211,19 +293,27 @@ class KLinearMarlin(KLinearBase):
weight = weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
if self.padding:
padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
weight = padded_weight
# Pack Marlin linear
-w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
weight, self.num_bits, self.group_size, self.act_order
)
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
self.g_idx = g_idx
self.sort_indices = sort_indices
self.k = weight.shape[0]
self.n = weight.shape[1]
self.loaded = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Only support input x as BF16 and FP16
...@@ -231,6 +321,11 @@ class KLinearMarlin(KLinearBase):
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, orig_shape[-1])
x = x.reshape(-1, x.shape[-1])
if self.padding:
padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)
padding_input[:,:self.orin_in_features] = x
x = padding_input
marlin_s = self.marlin_s.to(x.dtype)
x = KTransformersOps.gptq_marlin_gemm(
x,
...@@ -245,6 +340,11 @@ class KLinearMarlin(KLinearBase):
x.shape[-1],
self.is_k_full,
)
if self.padding:
x = x[:,:self.orin_out_features]
orig_shape[-1] = self.orin_out_features
else:
orig_shape[-1] = self.out_features
if self.has_bias:
x = x + self.bias
orig_shape[-1] = self.n
...@@ -365,7 +465,8 @@ class KLinearCPUInfer(KLinearBase):
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
-"KLinearCPUInfer": KLinearCPUInfer
"KLinearCPUInfer": KLinearCPUInfer,
"KLinearFP8": KLinearFP8,
}
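LINEAR_MAP is the registry KTransformersLinear consults when an optimize rule names a backend. A hedged sketch of how a rule's generate_op string is resolved (mirroring the constructor below; the instantiation line is commented out because the arguments depend on the injection context):

# hedged sketch of the name -> class lookup; not additional project code
op_name = "KLinearFP8"   # e.g. the generate_op value from an optimize rule
assert op_name in LINEAR_MAP, f"linear_type {op_name} not supported"
linear_cls = LINEAR_MAP[op_name]
# backend = linear_cls(key, gguf_loader, config, orig_module, device, **kwargs)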
class KTransformersLinear(BaseInjectedModule, KLinearBase):
...@@ -382,29 +483,18 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
prefill_op: str| None = "KLinearTorch",
**kwargs,
):
-BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
# build all the linear operators
if prefill_op is not None:
assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
-if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
-print(f"module info: key:{key} orig_module:{orig_module}")
-self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-else:
-self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
else:
self.prefill_linear = None
if generate_op is not None:
assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
-if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
-print(f"module info: key:{key} orig_module:{orig_module}")
-self.generate_op = "KLinearTorch"
-self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
-else:
-self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
else:
self.generate_linear = None
self.mode = InferenceState.UNLOAD
...@@ -412,10 +502,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
def forward(self, x):
if self.mode == InferenceState.PREFILL:
assert self.prefill_linear is not None, "cpu linear is not initialized"
-return self.prefill_linear.forward(x)
y = self.prefill_linear.forward(x)
else:
assert self.generate_linear is not None, "gpu linear is not initialized"
-return self.generate_linear.forward(x)
y = self.generate_linear.forward(x)
return y
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
if not mode:
...@@ -424,11 +515,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
if mode == InferenceState.PREFILL:
self.generate_linear.unload()
self.prefill_linear.load(w=w)
self.device = self.prefill_linear.device
self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
elif mode == InferenceState.GENERATE:
self.prefill_linear.unload()
self.generate_linear.load(w=w)
self.device = self.generate_linear.device
self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
elif mode == InferenceState.UNLOAD:
self.prefill_linear.unload()
self.generate_linear.unload()
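A hypothetical usage sketch of the mode switching above (names follow this file, prompt_hidden_states is a placeholder tensor; this is not a documented public API): loading one backend unloads the other, so only one copy of the weights is resident at a time.

# hypothetical usage, assuming `linear` is an injected KTransformersLinear instance
from ktransformers.util.utils import InferenceState

linear.load(mode=InferenceState.PREFILL)    # prefill backend loads, generate backend unloads
y = linear(prompt_hidden_states)            # dispatched to prefill_linear.forward
linear.load(mode=InferenceState.GENERATE)   # swap backends before decoding
linear.load(mode=InferenceState.UNLOAD)     # free both when the module is not needed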
......
...@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState
from ktransformers.util.utils import InferenceState, get_compute_capability
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (
...@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
if per_layer_prefill_flag:
causal_mask = None
else:
-causal_mask = self._update_causal_mask(
-attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-)
if os.name == 'nt' or get_compute_capability()<8:
print("for Windows or GPU before ampere, use forward_windows")
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
else:
causal_mask = None
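get_compute_capability is imported from ktransformers.util.utils above, but its implementation is not shown in this diff. A hedged sketch of what such a helper could look like, using torch.cuda.get_device_capability; the `< 8` check then falls back to the masked path on pre-Ampere GPUs, where flash attention is unavailable.

# assumption: a plausible helper, not necessarily the one in ktransformers.util.utils
import torch

def get_compute_capability(device: int = 0) -> int:
    if not torch.cuda.is_available():
        return 0
    major, _minor = torch.cuda.get_device_capability(device)
    return major   # 8 or higher means Ampere or newer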
# embed positions
hidden_states = inputs_embeds
......
...@@ -126,6 +126,8 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
gguf_loader=GGUFLoader(gguf_path)
with torch.device("meta"):
inject(module, optimize_config, model_config, gguf_loader)
# preload lm_head first because its intermediate result is large
load_weights(module.lm_head, gguf_loader, "lm_head.")
load_weights(module, gguf_loader)
module.gguf_loader = gguf_loader
del_meta(module)
......
...@@ -219,8 +219,20 @@
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
-name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)|(^lm_head)"
name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)"
replace:
class: "default"
kwargs:
......
...@@ -118,7 +118,18 @@
prefill_device: "cuda:0"
- match:
-name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
......
...@@ -15,6 +15,18 @@
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
......
...@@ -118,7 +118,18 @@
prefill_device: "cuda:0"
- match:
-name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
......