Commit 4d1d561d authored by chenxl

[feature] release 0.1.3

parent 67f8b370
......@@ -7,16 +7,22 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import torch
from torch import nn
import warnings
import torch.nn.functional as F
from ktransformers.operators.models import KLlamaModel
from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
logger = logging.getLogger("attention")
class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
attn_mask: Optional[torch.Tensor] = None
def __init__(self,
key: str,
......@@ -24,10 +30,12 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
......@@ -157,9 +165,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
chunck_size = 256 # TODO, generate chunck_size automatically.
if q_len <= chunck_size:
if q_len <= self.chunck_size:
return self.forward_chunck(
hidden_states,
attention_mask,
......@@ -176,24 +183,170 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cur_idx = 0
while cur_idx < q_len:
if attention_mask is not None:
chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + chunck_size, q_len), ...]
chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
else:
chunk_mask = None
# generate chunk_mask automatically.
self.attn_mask = \
torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
if self.attn_mask is None \
else self.attn_mask
self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \
-1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
[:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
self.attn_mask[:, :, :, :cur_idx] = 0
chunck_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))
cur_output, _, _ = self.forward_chunck(
hidden_states[:, cur_idx:min(cur_idx + chunck_size, q_len), ...],
chunk_mask,
position_ids[:, cur_idx:min(cur_idx + chunck_size, q_len)],
hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
chunck_mask,
position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
past_key_value,
output_attentions,
use_cache,
cache_position[cur_idx:min(cur_idx + chunck_size, q_len)],
cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
**kwargs
)
cur_idx += chunck_size
cur_idx += self.chunck_size
if attn_output is None:
attn_output = cur_output
else:
attn_output = torch.cat((attn_output, cur_output), dim=-2)
return attn_output, None, past_key_value
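# Worked example of the chunked prefill above: with the default chunck_size of
# 1000, a 2300-token prompt is processed as three slices, [0, 1000),
# [1000, 2000) and [2000, 2300). Each slice keeps the columns of already-cached
# positions open (mask 0 for columns < cur_idx), applies a causal triangle of
# -1e38 within its own columns, masks all later columns with -1e38, and the
# slice outputs are concatenated along the sequence dimension.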
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
class KLlamaAttention(BaseInjectedModule):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)
if q_len == 1:
position_ids = position_ids[0][-1].unsqueeze(0).unsqueeze(0)
query_states = query_states[:, :, -1:]
key_states = key_states[:, :, -1:]
attn_output = KLlamaModel.dynamic_sdpa.apply(
self.layer_idx,
bsz,
position_ids[0][0],
query_states.transpose(1, 2).to(torch.float16),
key_states.transpose(1, 2).to(torch.float16),
value_states.transpose(1, 2).to(torch.float16),
mode="prefill" if q_len > 1 else "generate",
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
\ No newline at end of file
#!/usr/bin/env python
# coding=utf-8
"""
Description : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing inference
with a Key-Value Cache on the CPU. The `CPUInferKVCache` class is responsible for configuring
and managing key-value caches, updating and retrieving cache data, and handling attention
operations. It supports different cache types (e.g., Q4_0, FP16) and retrieval strategies
(e.g., shared, separate). The `CPUInfer` class handles task submission and synchronization
on the CPU, with optional CUDA stream integration for tasks involving GPU acceleration.
These classes facilitate efficient caching and memory management for deep learning models
that leverage key-value attention mechanisms, particularly on CPU-based systems.
Author : djw
Date : 2024-08-26 23:25:24
Version : 1.0.0
LastEditors : djw
LastEditTime : 2024-08-26 23:25:24
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import sys, os
from typing import Any
import torch
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config
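# Illustrative usage sketch (the tensor variables here are placeholders, not
# defined in this module): KV-cache operations return task objects that are
# submitted to, and synchronized by, a CPUInfer instance, e.g.
#
#   cache = CPUInferKVCache(layer_num=32, kv_head_num=8, q_head_num=32,
#                           head_dim=128, block_len=256, kv_type="Q4_0")
#   cpu_infer = CPUInfer(Config().cpu_infer)
#   cpu_infer.submit(cache.get_and_update_kvcache_fp16(
#       k_cache_cpu, v_cache_cpu, layer_idx, block_table, max_block_num,
#       past_len, q_len))
#   cpu_infer.sync()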
class CPUInferKVCache:
def __init__(
self,
layer_num: int = 32,
kv_head_num: int = 8,
q_head_num: int = 32,
head_dim: int = 128,
block_len: int = 256,
anchor_num: int = 4,
anchor_type: str = "FIXED",
kv_type: str = "Q4_0",
retrieval_type: str = "SHARED",
layer_step: int = 1,
token_step: int = 1,
layer_offset: int = 0,
max_thread_num: int = 32,
max_batch_size: int = 4,
max_block_num: int = 512,
):
if anchor_type == "FIXED":
anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED
elif anchor_type == "QUEST":
anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST
elif anchor_type == "DYNAMIC":
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
elif anchor_type == "BLOCK_MEAN":
anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN
elif anchor_type == "BLOCK_MAX":
anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX
else:
raise ValueError(f"Unknown anchor type: {anchor_type}")
if kv_type == "FP16":
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
elif kv_type == "FP32":
assert False, "FP32 is not supported yet."
kv_type = cpuinfer_ext.kvcache.ggml_type.FP32
elif kv_type == "Q4_0":
kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0
elif kv_type == "Q8_0":
kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0
else:
raise ValueError(f"Unknown kv type: {kv_type}")
if retrieval_type == "SHARED":
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
elif retrieval_type == "INDIVIDUAL":
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD
elif retrieval_type == "SEPARATE":
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD
self.config = cpuinfer_ext.kvcache.KVCacheConfig(
layer_num,
kv_head_num,
q_head_num,
head_dim,
block_len,
anchor_num,
anchor_type,
kv_type,
retrieval_type,
layer_step,
token_step,
layer_offset,
max_block_num,
max_batch_size,
max_thread_num,
)
self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config)
def load_kvcache(self, tensor_file_path: str):
if not os.path.exists(tensor_file_path):
raise FileNotFoundError(f"The file {tensor_file_path} does not exist.")
return self.kvcache.load_kvcache(tensor_file_path,)
def dump_kvcache(
self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str
):
assert (
block_table.dim() == 1
and block_table.dtype == torch.int
and block_table.is_contiguous()
and block_table.device == torch.device("cpu")
), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
block_table.dim(),
block_table.size(),
block_table.dtype,
block_table.is_contiguous(),
block_table.device,
)
assert (
cache_total_len > 0
and cache_total_len <= self.config.block_len * block_table.size(0)
), "cache_total_len: {}".format(cache_total_len)
if not os.path.exists(os.path.dirname(tensor_file_path)):
os.makedirs(os.path.dirname(tensor_file_path))
return self.kvcache.dump_kvcache(
block_table.data_ptr(),
cache_total_len,
tensor_file_path,
)
def update_cache_total_len(self, cache_total_len: int):
assert cache_total_len > 0, "cache_total_len: {}".format(cache_total_len)
self.kvcache.update_cache_total_len(cache_total_len)
# q_in: (bsz, q_len, q_head_num, head_dim)
# output: (bsz, q_len, q_head_num, head_dim)
# attn_lse: (bsz, q_len, q_head_num)
# block_table: (bsz, max_block_num)
def attn(
self,
q_in: torch.Tensor,
output: torch.Tensor,
attn_lse: torch.Tensor,
layer_idx: int,
generate_token_idx: int,
block_table: torch.Tensor | None = None,
cache_seqlens: torch.Tensor | None = None,
pick_block_num: int | None = None,
init_block_num: int | None = None,
local_block_num: int | None = None,
):
assert (
q_in.dim() == 4
and q_in.size(2) == self.config.q_head_num
and q_in.size(3) == self.config.head_dim
and q_in.dtype == torch.float16
and q_in.is_contiguous()
and q_in.device == torch.device("cpu")
), "q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device
)
batch_size = q_in.size(0)
q_len = q_in.size(1)
assert (block_table is None) or (
block_table.dim() == 2
and block_table.size(0) == batch_size
and block_table.dtype == torch.int
and block_table.is_contiguous()
and block_table.device == torch.device("cpu")
), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
block_table.dim(),
block_table.size(),
block_table.dtype,
block_table.is_contiguous(),
block_table.device,
)
max_block_num = block_table.size(1) if block_table is not None else 0
assert (
output.dim() == 4
and output.size(0) == batch_size
and output.size(2) == self.config.q_head_num
and output.size(1) == q_len
and output.size(3) == self.config.head_dim
and output.dtype == torch.float16
and output.is_contiguous()
and output.device == torch.device("cpu")
), "output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
output.dim(),
output.size(),
output.dtype,
output.is_contiguous(),
output.device,
)
assert (
attn_lse.dim() == 3
and attn_lse.size(0) == batch_size
and attn_lse.size(1) == q_len
and attn_lse.size(2) == self.config.q_head_num
and attn_lse.dtype == torch.float32
and attn_lse.is_contiguous()
and attn_lse.device == torch.device("cpu")
), "attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
attn_lse.dim(),
attn_lse.size(),
attn_lse.dtype,
attn_lse.is_contiguous(),
attn_lse.device,
)
assert (
layer_idx >= 0 and layer_idx < self.config.layer_num
), "layer_idx: {}".format(layer_idx)
assert (cache_seqlens is None) or (
cache_seqlens.dim() == 1
and cache_seqlens.size(0) == batch_size
and cache_seqlens.dtype == torch.int
and cache_seqlens.is_contiguous()
and cache_seqlens.device == torch.device("cpu")
), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
cache_seqlens.dim(),
cache_seqlens.size(),
cache_seqlens.dtype,
cache_seqlens.is_contiguous(),
cache_seqlens.device,
)
return self.kvcache.attn(
q_in.data_ptr(),
output.data_ptr(),
attn_lse.data_ptr(),
layer_idx,
generate_token_idx,
q_len,
batch_size,
max_block_num,
block_table.data_ptr() if block_table is not None else 0,
cache_seqlens.data_ptr() if cache_seqlens is not None else 0,
pick_block_num,
init_block_num,
local_block_num,
)
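# Illustrative decode-step call (block_table and cache_seqlens are placeholder
# int32 CPU tensors): with the default 32 query heads and head_dim 128,
#
#   q_in   = torch.zeros(1, 1, 32, 128, dtype=torch.float16)
#   output = torch.zeros(1, 1, 32, 128, dtype=torch.float16)
#   lse    = torch.zeros(1, 1, 32, dtype=torch.float32)
#   cpu_infer.submit(cache.attn(q_in, output, lse, layer_idx=0,
#                               generate_token_idx=0, block_table=block_table,
#                               cache_seqlens=cache_seqlens))
#   cpu_infer.sync()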
# k_in: (block_len, kv_head_num, head_dim)
# v_in: (block_len, kv_head_num, head_dim)
def update_kvcache_one_block_fp16(
self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
):
assert (
k_in.dim() == 3
and k_in.size(1) == self.config.block_len
and k_in.size(0) == self.config.kv_head_num
and k_in.size(2) == self.config.head_dim
and k_in.dtype == torch.float16
and k_in.is_contiguous()
and k_in.device == torch.device("cpu")
), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
)
assert (
v_in.dim() == 3
and v_in.size(1) == self.config.block_len
and v_in.size(0) == self.config.kv_head_num
and v_in.size(2) == self.config.head_dim
and v_in.dtype == torch.float16
and v_in.is_contiguous()
and v_in.device == torch.device("cpu")
), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.update_one_block_fp16(
k_in.data_ptr(),
v_in.data_ptr(),
layer_id,
block_idx,
)
def get_kvcache_one_block_fp16(
self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
):
assert (
k_in.dim() == 3
and k_in.size(1) == self.config.block_len
and k_in.size(0) == self.config.kv_head_num
and k_in.size(2) == self.config.head_dim
and k_in.dtype == torch.float16
and k_in.is_contiguous()
and k_in.device == torch.device("cpu")
), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
)
assert (
v_in.dim() == 3
and v_in.size(1) == self.config.block_len
and v_in.size(0) == self.config.kv_head_num
and v_in.size(2) == self.config.head_dim
and v_in.dtype == torch.float16
and v_in.is_contiguous()
and v_in.device == torch.device("cpu")
), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.get_one_block_fp16(
k_in.data_ptr(),
v_in.data_ptr(),
layer_id,
block_idx,
)
def update_importance_one_block(
self, importance: torch.Tensor, layer_id: int, block_idx: int
):
assert (
importance.dim() == 1
and importance.size(0) == self.config.block_len
and importance.dtype == torch.float16
and importance.is_contiguous()
and importance.device == torch.device("cpu")
), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
importance.dim(),
importance.size(),
importance.dtype,
importance.is_contiguous(),
importance.device,
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.update_importance_one_block(
importance.data_ptr(),
layer_id,
block_idx,
)
def get_importance_one_block(
self, importance: torch.Tensor, layer_id: int, block_idx: int
):
assert (
importance.dim() == 1
and importance.size(0) == self.config.block_len
and importance.dtype == torch.float16
and importance.is_contiguous()
and importance.device == torch.device("cpu")
), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
importance.dim(),
importance.size(),
importance.dtype,
importance.is_contiguous(),
importance.device,
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.get_importance_one_block(
importance.data_ptr(),
layer_id,
block_idx,
)
def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):
assert (
anchor.dim() == 3
and anchor.size(0) == self.config.kv_head_num
and anchor.size(1) == self.config.anchor_num
and anchor.size(2) == self.config.head_dim
and anchor.dtype == torch.float16
and anchor.is_contiguous()
and anchor.device == torch.device("cpu")
), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
anchor.dim(),
anchor.size(),
anchor.dtype,
anchor.is_contiguous(),
anchor.device,
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.get_anchor_one_block(
anchor.data_ptr(),
layer_id,
block_idx,
)
def update_anchor_one_block(
self, anchor: torch.Tensor, layer_id: int, block_idx: int
):
assert (
anchor.dim() == 3
and anchor.size(0) == self.config.kv_head_num
and anchor.size(1) == self.config.anchor_num
and anchor.size(2) == self.config.head_dim
and anchor.dtype == torch.float16
and anchor.is_contiguous()
and anchor.device == torch.device("cpu")
), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
anchor.dim(),
anchor.size(),
anchor.dtype,
anchor.is_contiguous(),
anchor.device,
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.update_anchor_one_block(
anchor.data_ptr(),
layer_id,
block_idx,
)
def calc_anchor_all_layers(
self,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
):
assert (
block_table.dim() == 2
and block_table.size(0) == cache_seqlens.size(0)
and block_table.dtype == torch.int
and block_table.is_contiguous()
and block_table.device == torch.device("cpu")
), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
block_table.dim(),
block_table.size(),
block_table.dtype,
block_table.is_contiguous(),
block_table.device,
)
assert (
cache_seqlens.dim() == 1
and cache_seqlens.dtype == torch.int
and cache_seqlens.is_contiguous()
and cache_seqlens.device == torch.device("cpu")
), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
cache_seqlens.dim(),
cache_seqlens.size(),
cache_seqlens.dtype,
cache_seqlens.is_contiguous(),
cache_seqlens.device,
)
batch_size = block_table.size(0)
max_block_num = block_table.size(1)
return self.kvcache.calc_anchor_all_layers(
block_table.data_ptr(),
cache_seqlens.data_ptr(),
batch_size,
max_block_num,
)
def clear_importance_all_layers(
self,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
):
assert (
block_table.dim() == 2
and block_table.size(0) == cache_seqlens.size(0)
and block_table.dtype == torch.int
and block_table.is_contiguous()
and block_table.device == torch.device("cpu")
), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
block_table.dim(),
block_table.size(),
block_table.dtype,
block_table.is_contiguous(),
block_table.device,
)
assert (
cache_seqlens.dim() == 1
and cache_seqlens.dtype == torch.int
and cache_seqlens.is_contiguous()
and cache_seqlens.device == torch.device("cpu")
), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
cache_seqlens.dim(),
cache_seqlens.size(),
cache_seqlens.dtype,
cache_seqlens.is_contiguous(),
cache_seqlens.device,
)
batch_size = block_table.size(0)
max_block_num = block_table.size(1)
return self.kvcache.clear_importance_all_layers(
block_table.data_ptr(),
cache_seqlens.data_ptr(),
batch_size,
max_block_num,
)
def get_cache_total_len(self):
return self.kvcache.get_cache_total_len()
def update_kvcache_q4(
self,
k_in: torch.Tensor,
k_scales: torch.Tensor,
v_in: torch.Tensor,
v_scales: torch.Tensor,
layer_id: int,
seq_offset: int | None = None,
seq_len: int | None = None,
block_table: torch.Tensor | None = None,
):
raise NotImplementedError
def update_kvcache_fp16(
self,
k_in: torch.Tensor,
v_in: torch.Tensor,
layer_idx,
block_table: torch.Tensor,
max_block_num,
past_len: torch.Tensor,
q_len,
):
batch_size = block_table.size(0)
return self.kvcache.get_kvcache_fp16(
k_in.data_ptr(),
v_in.data_ptr(),
layer_idx,
block_table.data_ptr(),
batch_size,
max_block_num,
past_len.data_ptr(),
q_len
)
def get_kvcache_q4(
self,
k_in: torch.Tensor,
k_scales: torch.Tensor,
v_in: torch.Tensor,
v_scales: torch.Tensor,
layer_id: int,
seq_offset: int | None = None,
seq_len: int | None = None,
block_table: torch.Tensor | None = None,
):
raise NotImplementedError
def get_kvcache_fp16(
self,
k_in: torch.Tensor,
v_in: torch.Tensor,
layer_idx,
block_table: torch.Tensor,
max_block_num,
past_len: torch.Tensor,
):
batch_size = block_table.size(0)
return self.kvcache.get_kvcache_fp16(
k_in.data_ptr(),
v_in.data_ptr(),
layer_idx,
block_table.data_ptr(),
batch_size,
max_block_num,
past_len.data_ptr(),
)
def get_and_update_kvcache_fp16(
self,
k_cache_cpu: torch.Tensor,
v_cache_cpu: torch.Tensor,
layer_idx,
block_table: torch.Tensor,
max_block_num,
past_len: torch.Tensor,
q_len,
):
batch_size = block_table.size(0)
return self.kvcache.get_and_update_kvcache_fp16(
k_cache_cpu.data_ptr(),
v_cache_cpu.data_ptr(),
layer_idx,
block_table.data_ptr(),
batch_size,
max_block_num,
past_len.data_ptr(),
q_len,
)
def update_importance(
self,
importance_cache: torch.Tensor,
layer_idx,
block_table: torch.Tensor,
max_block_num,
offset: torch.Tensor,
width,
):
batch_size = block_table.size(0)
return self.kvcache.update_importance(
importance_cache.data_ptr(),
layer_idx,
block_table.data_ptr(),
batch_size,
max_block_num,
offset.data_ptr(),
width,
)
# attn_sparsity: ((bsz, q_len, q_head_num), dtype = torch.float32)
def get_attn_sparsity(
self,
q_in: torch.Tensor,
attn_sparsity: torch.Tensor,
layer_idx: int,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
block_table_origin: torch.Tensor,
cache_seqlens_origin: torch.Tensor,
generate_token_idx: int = 0,
topk: int | None = None,
local: int | None = None,
):
batch_size = block_table.size(0)
max_block_num = block_table.size(1)
max_block_num_origin = block_table_origin.size(1)
q_len = q_in.size(1)
if topk is None or local is None or topk + local >= max_block_num:
topk = -1
local = -1
return self.kvcache.get_attn_sparsity(
q_in.data_ptr(),
attn_sparsity.data_ptr(),
layer_idx,
generate_token_idx,
q_len,
batch_size,
max_block_num,
block_table.data_ptr(),
cache_seqlens.data_ptr(),
block_table_origin.data_ptr(),
cache_seqlens_origin.data_ptr(),
max_block_num_origin,
topk,
local,
)
def attn_with_kvcache(
self,
q_in: torch.Tensor,
k_in: torch.Tensor,
v_in: torch.Tensor,
output: torch.Tensor,
attn_lse: torch.Tensor,
layer_idx: int,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
generate_token_idx: int = 0,
topk: int | None = None,
local: int | None = None,
):
batch_size = block_table.size(0)
max_block_num = block_table.size(1)
q_len = q_in.size(1)
if topk is None or local is None or topk + local >= max_block_num:
topk = -1
local = -1
return self.kvcache.attn_with_kvcache(
q_in.data_ptr(),
k_in.data_ptr(),
v_in.data_ptr(),
output.data_ptr(),
attn_lse.data_ptr(),
layer_idx,
generate_token_idx,
q_len,
batch_size,
max_block_num,
block_table.data_ptr(),
cache_seqlens.data_ptr(),
topk,
local,
)
def get_all_kvcache_one_layer(
self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int
):
return self.kvcache.get_all_kvcache_one_layer(
k_in.data_ptr(),
v_in.data_ptr(),
layer_id,
)
def get_importance(
self,
importance: torch.Tensor,
block_table: torch.Tensor,
):
raise NotImplementedError
def get_anchor(
self,
anchor: torch.Tensor,
block_table: torch.Tensor,
):
raise NotImplementedError
class CPUInfer:
cpu_infer = None
def __init__(self, cpu_infer: int = Config().cpu_infer):
# Lazily create a single shared backend instance.
if CPUInfer.cpu_infer is None:
CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)
def submit(self, task):
CPUInfer.cpu_infer.submit(task)
def submit_with_cuda_stream(self, current_cuda_stream, task):
CPUInfer.cpu_infer.submit_with_cuda_stream(current_cuda_stream, task)
def sync(self):
CPUInfer.cpu_infer.sync()
def sync_with_cuda_stream(self, current_cuda_stream):
CPUInfer.cpu_infer.sync_with_cuda_stream(current_cuda_stream)
\ No newline at end of file
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : Jianwei Dong
Date : 2024-08-26 23:25:24
Version : 1.0.0
LastEditors : Jianwei Dong
LastEditTime : 2024-08-26 23:25:24
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import torch
from transformers import AutoConfig
import sys, os
import logging
logger = logging.getLogger("dynamic_attention")
sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
from flash_attn import flash_attn_func, flash_attn_with_kvcache
import math
import json
class DynamicScaledDotProductAttention:
remaining_length: int
def __init__(
self,
max_seq_len: int,
block_size: int,
config: AutoConfig,
device: torch.device,
local_windows_len: int,
topk: int,
threads_num: int,
anchor_type: str = "DYNAMIC",
kv_type: str = "FP16",
dense_layer_num: int = 0,
anchor_num: int = 1,
block_selection_mode: str = "SHARED",
layer_step: int = 1,
token_step: int = 1,
preselect_block: bool = False,
preselect_block_count: int = 96,
prefill_chunk_size: int = 20480,
use_attn_sparsity: bool = False,
):
# assert anchor_num == 1
# assert anchor_type == "DYNAMIC"
self.remaining_length = 0
valid_anchor_types = ["DYNAMIC", "FIXED", "BLOCK_MEAN", "BLOCK_MAX", "QUEST"]
assert anchor_type in valid_anchor_types
if anchor_type == "QUEST":
assert anchor_num == 2
elif anchor_type != "FIXED" and anchor_type != "DYNAMIC":
assert anchor_num == 1
valid_kv_types = ["FP16", "FP32", "Q4_0", "Q8_0"]
assert kv_type in valid_kv_types
if kv_type != "FP16" and kv_type != "FP32":
assert block_size % 32 == 0
valid_block_selection_modes = ["SHARED", "SEPARATE"] # individual
assert block_selection_mode in valid_block_selection_modes
self.max_seq_len = max_seq_len
self.block_num = max_seq_len // block_size
self.block_size = block_size
self.anchor_type = anchor_type
self.kv_type = kv_type
self.anchor_num = anchor_num
self.threads_num = threads_num
self.layer_step = layer_step
self.token_step = token_step
self.preselect_block = preselect_block
self.preselect_block_count = preselect_block_count
self.block_selection_mode = block_selection_mode
self.use_attn_sparsity = use_attn_sparsity
# model config
self.kv_head_num = config.num_key_value_heads
self.q_head_num = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.layer_num = config.num_hidden_layers
self.device = device
self.local_windows_len = local_windows_len
self.local_block_num = self.local_windows_len // self.block_size + 1
self.prefill_chunk_size = prefill_chunk_size
self.topk = topk
self.dense_layer_num = dense_layer_num
# self.dense_layer_num = 32
self.cache_key_states = torch.zeros(
(self.block_num, block_size, self.kv_head_num, self.head_dim),
device=device,
dtype=torch.float16,
)
self.cache_value_states = torch.zeros(
(self.block_num, block_size, self.kv_head_num, self.head_dim),
device=device,
dtype=torch.float16,
)
# [max_num_block, block_size, head_num]
self.cache_importance = torch.zeros(
(self.block_num, block_size, self.q_head_num),
device=device,
dtype=torch.float16,
)
# key_states: [bsz, q_len, kv_head_num, head_dim]
# value_states: [bsz, q_len, kv_head_num, head_dim]
# query_states: [bsz, q_len, q_head_num, head_dim]
self.q_in_cpu = torch.zeros(
(1, 1, self.q_head_num, self.head_dim),
device="cpu",
dtype=torch.float16,
pin_memory=True,
)
self.k_in_cpu = torch.zeros(
(1, 1, self.kv_head_num, self.head_dim),
device="cpu",
dtype=torch.float16,
pin_memory=True,
)
self.v_in_cpu = torch.zeros(
(1, 1, self.kv_head_num, self.head_dim),
device="cpu",
dtype=torch.float16,
pin_memory=True,
)
self.cache_seqlens_cpu = torch.empty(
(1,), device="cpu", dtype=torch.int32, pin_memory=True
)
self.cache_seqlens_cuda = torch.empty((1,), device=device, dtype=torch.int32)
self.prefix_block_table = torch.arange(
self.block_num, device="cpu", dtype=torch.int32, pin_memory=True
).view(1, -1)
self.block_table_cpu = torch.arange(
self.block_num, device="cpu", dtype=torch.int32, pin_memory=True
).view(1, -1)
# assert (
# self.local_windows_len // self.block_size + 1 + self.preselect_block_count
# <= self.block_num
# )
self.output_cpu = torch.empty(
(1, 1, self.q_head_num, self.head_dim),
device="cpu",
dtype=torch.float16,
pin_memory=True,
)
self.lse_cpu = torch.empty(
(1, 1, self.q_head_num), device="cpu", dtype=torch.float32, pin_memory=True
)
self.output_cuda = torch.empty(
(1, 1, self.q_head_num, self.head_dim), device=device, dtype=torch.float16
)
self.attn_sparsity = torch.zeros(
(1, 1, self.q_head_num), device="cpu", dtype=torch.float32, pin_memory=True
)
if preselect_block:
self.preselect_block_table = torch.zeros(
self.layer_num,
self.preselect_block_count,
device=device,
dtype=torch.int32,
)
self.preselect_block_num = 0 # block_num before preselect
self.evict_tokens = 0
self.cpu_infer = CPUInfer(threads_num)
self.local_thread = CPUInferKVCache(
self.layer_num,
self.kv_head_num,
self.q_head_num,
self.head_dim,
self.block_size,
anchor_num=self.anchor_num,
anchor_type=anchor_type,
kv_type=self.kv_type,
retrieval_type=self.block_selection_mode,
layer_step=self.layer_step,
token_step=self.token_step,
layer_offset=self.dense_layer_num % self.layer_step,
max_batch_size=1,
max_block_num=self.block_num,
max_thread_num=self.threads_num,
)
print(
f"local_windows_len: {local_windows_len}, topk: {topk}, dense_layer_num: {dense_layer_num}, kv_type: {self.kv_type}, anchor_type: {self.anchor_type}, preselect_block: {self.preselect_block}, preselect_block_count: {self.preselect_block_count}, token_step: {self.token_step}, layer_step: {self.layer_step}"
)
self.shape_mask = (
self.q_head_num,
self.block_size,
self.block_size,
)
mask = torch.zeros(
self.shape_mask, dtype=torch.uint8, device=device
).contiguous()
elm_idx = torch.arange(self.block_size, device=device)
for i in range(mask.size(-2)):
idx = i + mask.size(-1) - mask.size(-2) - elm_idx
idx = idx[idx >= 0]
mask[..., i, idx] = 1
self.tril_mask = mask
self.triu_mask = mask ^ 1
self.generate_token_idx = 0
def get_attn_score_one_block(
self,
batch_idx: int,
max_block_num: int,
query: torch.Tensor,
key: torch.Tensor,
offset: int,
width: int,
mask_mode: str | None = None,
use_softmax: bool = True,
):
n_rep = self.q_head_num // self.kv_head_num
importance = self.cache_importance.view(-1, self.q_head_num)
importance = importance.narrow(0, batch_idx * max_block_num + offset, width)
n_gqa_ = self.q_head_num // self.kv_head_num
for head_idx in range(self.q_head_num):
key_item = key[..., head_idx // n_gqa_, :].view(key.size(0), -1)
qk = torch.einsum(
"qd,kd->qk", query[:, head_idx, :], key_item
)  # qk: (len_q, len_k) for this head
if mask_mode == "tril":
mask = self.tril_mask
mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]
qk = qk * mask
elif mask_mode == "triu":
mask = self.triu_mask
mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]
qk = qk * mask
if use_softmax:
qk = torch.nn.functional.softmax(
qk / math.sqrt(self.head_dim), dim=-1, dtype=torch.float32
).to(torch.float16)
qk = torch.sum(qk, dim=-2)
importance[...,head_idx] += qk
def get_preselect_block_table_and_attn_score(
self,
layer_idx: int,
batch_size: int,
offset: torch.Tensor,
width: int,
query: torch.Tensor,
key: torch.Tensor,
union_with_last_layer: bool = True,
):
max_seqs_len = offset.max().item() + width
max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size
for batch_idx in range(batch_size):
query_cur = query[batch_idx][-128:]
self.get_attn_score_one_block(
batch_idx,
max_block_num,
query_cur,
key[batch_idx][: offset[batch_idx].item() + width],
0,
offset[batch_idx].item() + width,
mask_mode=None,
)
if self.preselect_block:
self.prefill_block_num = max(
0, max_block_num - self.local_windows_len // self.block_size
)
self.evict_tokens = (
max(self.prefill_block_num - self.preselect_block_count, 0)
* self.block_size
)
if self.prefill_block_num != 0:
importance_cache = self.cache_importance.narrow(
0, 0, self.prefill_block_num * batch_size
).view(
batch_size, self.prefill_block_num, self.block_size, self.q_head_num
)
importance_r = importance_cache[:, 1:, : self.block_size // 4]
pad_r = torch.zeros_like(importance_r[:, :1])
importance_r = torch.cat((importance_r, pad_r), dim=1)
importance_l = importance_cache[:, :-1, -self.block_size // 4 :]
pad_l = torch.zeros_like(importance_l[:, :1])
importance_l = torch.cat((pad_l, importance_l), dim=1)
importance = torch.cat(
(importance_l, importance_cache, importance_r), dim=2
)
importance = importance.mean(dim=-1)
importance = importance.mean(dim=-1)
# importance: (batch_size, max_block_num)
topk = min(self.preselect_block_count, self.prefill_block_num)
values, indices = torch.topk(
importance,
k=topk,
dim=1,
)
self.preselect_block_table[
layer_idx : layer_idx + 1,
:topk,
].copy_(indices)
if union_with_last_layer and layer_idx == 31:
for tmp_layer_idx in range(self.layer_num - 1):
for i in range(1, min(topk, 6)):
x = self.preselect_block_table[-1, i]
if x not in self.preselect_block_table[tmp_layer_idx]:
self.preselect_block_table[tmp_layer_idx, topk - i] = x
if self.anchor_type == "DYNAMIC":
importance_cache = self.cache_importance.narrow(
0, 0, max_block_num * batch_size
).view(batch_size, max_block_num * self.block_size, self.q_head_num)
importance_cache_cpu = torch.empty_like(
importance_cache, device="cpu", pin_memory=True
)
importance_cache_cpu.copy_(importance_cache)
block_table_cpu = self.prefix_block_table[:, :max_block_num].to("cpu")
offset_cpu = offset.contiguous().to("cpu")
self.cpu_infer.submit(
self.local_thread.update_importance(
importance_cache_cpu,
layer_idx,
block_table_cpu,
max_block_num,
offset_cpu,
width,
)
)
self.cpu_infer.sync()
importance_cache = self.cache_importance.narrow(
0, 0, max_block_num * batch_size
).view(batch_size, max_block_num * self.block_size, self.q_head_num)
importance_cache.zero_()
# key: [bsz, past_len, head_num, head_dim] float16
# query: [bsz, q_len, q_head_num, head_dim] float16
def get_attn_score(
self,
layer_idx: int,
batch_size: int,
offset: torch.Tensor,
width: int,
query: torch.Tensor,
key: torch.Tensor,
):
max_seqs_len = offset.max().item() + width
max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size
for batch_idx in range(batch_size):
for idx in range(width // self.block_size):
offset_cur = idx * self.block_size
query_cur = query[batch_idx, offset_cur : offset_cur + self.block_size]
self.get_attn_score_one_block(
batch_idx,
max_block_num,
query_cur,
key[
batch_idx,
offset[batch_idx]
+ offset_cur : offset[batch_idx]
+ offset_cur
+ self.block_size,
],
offset[batch_idx].item() + offset_cur,
self.block_size,
mask_mode="tril",
use_softmax=False,
)
offset_key = (
offset[batch_idx].item()
+ idx * self.block_size
- self.local_windows_len
)
if offset_key >= 0:
self.get_attn_score_one_block(
batch_idx,
max_block_num,
query_cur,
key[batch_idx, offset_key : offset_key + self.block_size],
offset_key,
self.block_size,
mask_mode="triu",
use_softmax=False,
)
offset_key = max(0, offset_key + self.block_size)
width_key = (
offset[batch_idx].item() + idx * self.block_size - offset_key
)
if width_key > 0:
self.get_attn_score_one_block(
batch_idx,
max_block_num,
query_cur,
key[batch_idx, offset_key : offset_key + width_key],
offset_key,
width_key,
mask_mode=None,
use_softmax=False,
)
importance_cache = self.cache_importance.narrow(
0, 0, max_block_num * batch_size
).view(batch_size, max_block_num * self.block_size, self.q_head_num)
importance_cache_cpu = torch.empty_like(
importance_cache, device="cpu", pin_memory=True
)
importance_cache_cpu.copy_(importance_cache)
block_table_cpu = self.prefix_block_table[:, :max_block_num].to("cpu")
offset_cpu = offset.contiguous().to("cpu")
self.cpu_infer.submit(
self.local_thread.update_importance(
importance_cache_cpu,
layer_idx,
block_table_cpu,
max_block_num,
offset_cpu,
width,
)
)
self.cpu_infer.sync()
importance_cache.zero_()
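# Reading of the scoring loop above: each query block accumulates importance in
# up to three passes -- a causally ("tril") masked pass over its own key block,
# a "triu"-masked pass over the key block at the trailing edge of the local
# window, and an unmasked pass over the fully visible keys in between -- before
# the accumulated scores are flushed to the CPU-side importance cache and the
# GPU-side buffer is zeroed.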
# key: [bsz, q_len, head_num, head_dim] float16
# value: [bsz, q_len, head_num, head_dim] float16
def swap_in_and_swap_out(self, layer_idx, past_len, q_len, key, value):
batch_size = 1
max_seqs_len = past_len.max().item() + q_len
max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size
k_cache = self.cache_key_states.narrow(0, 0, max_block_num * batch_size).view(
batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim
)
v_cache = self.cache_value_states.narrow(0, 0, max_block_num * batch_size).view(
batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim
)
for batch_idx in range(batch_size):
offset = past_len[batch_idx]
width = q_len
k_cache[batch_idx][offset : offset + width].copy_(
key[batch_idx].view(-1, self.kv_head_num, self.head_dim)
)
v_cache[batch_idx][offset : offset + width].copy_(
value[batch_idx].view(-1, self.kv_head_num, self.head_dim)
)
k_cache_cpu = torch.empty_like(k_cache, device="cpu", pin_memory=True)
v_cache_cpu = torch.empty_like(v_cache, device="cpu", pin_memory=True)
k_cache_cpu.copy_(k_cache)
v_cache_cpu.copy_(v_cache)
cur_block_num = (
q_len + past_len[0].item() + self.block_size - 1
) // self.block_size
block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
past_len_cpu = past_len.contiguous().to("cpu")
self.cpu_infer.submit(
self.local_thread.get_and_update_kvcache_fp16(
k_cache_cpu,
v_cache_cpu,
layer_idx,
block_table_cpu,
max_block_num,
past_len_cpu,
q_len,
)
)
self.cpu_infer.sync()
k_cache.copy_(k_cache_cpu)
v_cache.copy_(v_cache_cpu)
return k_cache, v_cache
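# Example of the block arithmetic in swap_in_and_swap_out: with block_size = 256,
# past_len = 1000 and q_len = 300, max_seqs_len is 1300 and max_block_num is
# ceil(1300 / 256) = 6, so six 256-token blocks are staged in the GPU caches,
# mirrored into pinned CPU buffers, and handed to get_and_update_kvcache_fp16.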
def calc_anchor(self, cache_seqlens: int):
cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
cache_seqlens_cpu = torch.tensor(
[cache_seqlens], device="cpu", dtype=torch.int32
)
self.cpu_infer.submit(
self.local_thread.calc_anchor_all_layers(
block_table_cpu,
cache_seqlens_cpu,
)
)
self.cpu_infer.sync()
def clear_importance(self, cache_seqlens: int):
print(f"clear importance: {cache_seqlens}")
cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
cache_seqlens_cpu = torch.tensor(
[cache_seqlens], device="cpu", dtype=torch.int32
)
self.cpu_infer.submit(
self.local_thread.clear_importance_all_layers(
block_table_cpu,
cache_seqlens_cpu,
)
)
self.cpu_infer.sync()
def clear_kvcache(self, cache_seqlens: int):
cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size
block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu")
cache_seqlens_cpu = torch.tensor(
[cache_seqlens], device="cpu", dtype=torch.int32
)
self.cpu_infer.submit(
self.local_thread.clear_kvcache_all_layers(
block_table_cpu,
cache_seqlens_cpu,
)
)
self.cpu_infer.sync()
def get_attn_sparsity(
self,
q_in: torch.Tensor,
layer_idx: int,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
block_table_origin: torch.Tensor,
cache_seqlens_origin: torch.Tensor,
generate_token_idx: int = 0,
topk: int | None = None,
local: int | None = None,
output_path: str = "./attn_sparsity.json",
):
self.attn_sparsity.zero_()
self.cpu_infer.submit(
self.local_thread.get_attn_sparsity(
q_in,
self.attn_sparsity,
layer_idx,
block_table,
cache_seqlens,
block_table_origin,
cache_seqlens_origin,
generate_token_idx,
topk,
local,
)
)
self.cpu_infer.sync()
with open(output_path, "a") as file:
for head_idx in range(self.q_head_num):
sparsity = self.attn_sparsity[0][0][head_idx].item()
json_obj = {
"token_idx": generate_token_idx,
"layer_idx": layer_idx,
"head_idx": head_idx,
"sparsity": sparsity,
}
json.dump(json_obj, file)
file.write("\n")
def apply(
self,
layer_idx: int,
bsz: int,
past_len: int,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
mode: str = "prefill",
generate_token_idx: int = -1,
):
# key_states: [bsz, q_len, kv_head_num, head_dim]
# value_states: [bsz, q_len, kv_head_num, head_dim]
# query_states: [bsz, q_len, q_head_num, head_dim]
assert query_states.dtype == torch.float16
assert key_states.dtype == torch.float16
assert value_states.dtype == torch.float16
assert key_states.size(2) == self.kv_head_num
assert value_states.size(2) == self.kv_head_num
assert query_states.size(2) == self.q_head_num
q_len = query_states.size(1)
batch_size = query_states.size(0)
self.cache_seqlens_cuda.fill_(past_len)
last_chunk = False
if self.remaining_length <= self.prefill_chunk_size and q_len != 1:
last_chunk = True
device = query_states.device
if layer_idx == 0:
if q_len == 1:
self.generate_token_idx += 1
elif last_chunk:
self.generate_token_idx = -1
if mode == "prefill":
key, value = self.swap_in_and_swap_out(
layer_idx,
self.cache_seqlens_cuda,
q_len,
key_states,
value_states,
)
if last_chunk and (self.anchor_type == "DYNAMIC" or self.preselect_block):
self.get_preselect_block_table_and_attn_score(
layer_idx,
bsz,
self.cache_seqlens_cuda,
q_len,
query_states,
key,
)
output = flash_attn_with_kvcache(
q=query_states,
k_cache=key,
v_cache=value,
cache_seqlens=self.cache_seqlens_cuda + q_len,
causal=True,
)
return output.transpose(1, 2)
elif mode == "generate":
assert self.generate_token_idx >= 0
self.q_in_cpu.copy_(query_states, non_blocking=True)
self.k_in_cpu.copy_(key_states, non_blocking=True)
self.v_in_cpu.copy_(value_states, non_blocking=True)
self.cache_seqlens_cpu.copy_(self.cache_seqlens_cuda, non_blocking=True)
# print(layer_idx)
if layer_idx < self.dense_layer_num:
self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(
torch.cuda.current_stream("cuda").cuda_stream,
self.local_thread.attn_with_kvcache(
q_in=self.q_in_cpu,
k_in=self.k_in_cpu,
v_in=self.v_in_cpu,
output=self.output_cpu,
attn_lse=self.lse_cpu,
layer_idx=layer_idx,
block_table=self.block_table_cpu,
cache_seqlens=self.cache_seqlens_cpu,
),
)
else:
if self.preselect_block:
self.cache_seqlens_cpu.copy_(
self.cache_seqlens_cuda - self.evict_tokens, non_blocking=True
)
if self.preselect_block_count < self.prefill_block_num:
self.block_table_cpu[:, : self.preselect_block_count].copy_(
self.preselect_block_table[layer_idx : layer_idx + 1],
non_blocking=True,
)
self.block_table_cpu[
:,
self.preselect_block_count : self.preselect_block_count
+ self.local_block_num,
].copy_(
self.prefix_block_table[
:,
self.prefill_block_num : self.prefill_block_num
+ self.local_block_num,
],
non_blocking=True,
)
# print("submit_with_cuda_stream")
self.cpu_infer.submit_with_cuda_stream(
torch.cuda.current_stream("cuda").cuda_stream,
self.local_thread.attn_with_kvcache(
q_in=self.q_in_cpu,
k_in=self.k_in_cpu,
v_in=self.v_in_cpu,
output=self.output_cpu,
attn_lse=self.lse_cpu,
layer_idx=layer_idx,
generate_token_idx=self.generate_token_idx,
block_table=self.block_table_cpu,
cache_seqlens=self.cache_seqlens_cpu,
topk=(
self.topk
if self.topk <= self.preselect_block_count
else None
),
local=self.local_windows_len // self.block_size,
),
)
# print("submit_with_cuda_stream enqueue\n")
else:
self.block_table_cpu.copy_(
self.prefix_block_table, non_blocking=True
)
self.cpu_infer.submit_with_cuda_stream(
torch.cuda.current_stream("cuda").cuda_stream,
self.local_thread.attn_with_kvcache(
q_in=self.q_in_cpu,
k_in=self.k_in_cpu,
v_in=self.v_in_cpu,
output=self.output_cpu,
attn_lse=self.lse_cpu,
layer_idx=layer_idx,
generate_token_idx=self.generate_token_idx,
block_table=self.block_table_cpu,
cache_seqlens=self.cache_seqlens_cpu,
topk=self.topk,
local=self.local_windows_len // self.block_size,
),
)
self.cpu_infer.sync_with_cuda_stream(
torch.cuda.current_stream("cuda").cuda_stream
)
# print("submit_with_cuda_stream finished\n")
self.output_cuda.copy_(self.output_cpu, non_blocking=True)
return self.output_cuda.transpose(1, 2)
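# Summary of apply(): "prefill" calls stage the new keys/values through
# swap_in_and_swap_out, optionally collect block-importance scores on the last
# prefill chunk, and run flash_attn_with_kvcache on the GPU; "generate" calls
# copy the single-token q/k/v into pinned CPU buffers and run the sparse
# CPU-side attn_with_kvcache, with dense layers (layer_idx < dense_layer_num)
# always using the full block table.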
def save(self, path: str, length: int):
cur_block_num = (length + self.block_size - 1) // self.block_size
block_table_cpu = self.prefix_block_table[0, :cur_block_num].to("cpu")
self.cpu_infer.submit(
self.local_thread.dump_kvcache(
block_table_cpu,
length,
path,
)
)
self.cpu_infer.sync()
def load(self, path: str, length: int):
self.cpu_infer.submit(
self.local_thread.load_kvcache(
path,
)
)
self.cpu_infer.sync()
......@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-25 11:25:24
Version : 0.1.0
LastEditors : Azure
LastEditTime : 2024-08-15 02:36:29
LastEditTime : 2024-08-27 03:50:23
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
......@@ -436,7 +436,7 @@ class KExpertsTorch(KExpertsBase):
final_hidden_states.index_add_(0, top_x, current_hidden_states)
return final_hidden_states.to(org_dtype, device=org_device)
return final_hidden_states.to(dtype=org_dtype, device=org_device)
EXPERTS_MAP = {
"KExpertsCPU": KExpertsCPU,
......
#!/usr/bin/env python
# coding=utf-8
'''
"""
Description :
Author : Azure-Tang
Date : 2024-07-25 11:25:24
Version : 1.0.0
LastEditors : Azure
LastEditTime : 2024-08-14 14:53:05
LastEditTime : 2024-08-27 07:29:04
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
"""
import inspect
import math
......@@ -19,7 +19,10 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention
from ktransformers.server.config.config import Config
import os
import yaml
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import (
......@@ -40,19 +43,35 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock, Qwen2MoeMLP, Qwen2MoeDecoderLayer
from ktransformers.models.modeling_deepseek import BaseModelOutputWithPast, DeepseekV2DecoderLayer, DeepseekV2MoE
from ktransformers.models.modeling_qwen2_moe import (
Qwen2MoeSparseMoeBlock,
Qwen2MoeMLP,
Qwen2MoeDecoderLayer,
)
from ktransformers.models.modeling_deepseek import (
BaseModelOutputWithPast,
DeepseekV2DecoderLayer,
DeepseekV2MoE,
)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (
LlamaDecoderLayer,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
_flash_supports_window_size = "window_size" in list(
inspect.signature(flash_attn_func).parameters
)
logger = logging.get_logger(__name__)
......@@ -151,6 +170,7 @@ QWEN2MOE_INPUTS_DOCSTRING = r"""
the complete sequence length.
"""
@add_start_docstrings(
"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
QWEN2MOE_START_DOCSTRING,
......@@ -162,18 +182,21 @@ class KQwen2MoeModel(BaseInjectedModule):
Args:
config: Qwen2MoeConfig
"""
def __init__(
self,
key: str,
gguf_loader : GGUFLoader,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, device, **kwargs
)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
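# transfer_map, when provided, maps a decoder-layer index to the device that
# layer should run on; for example, a map like {16: "cuda:1"} (hypothetical
# values) would move hidden states, masks and positions to the second GPU once
# layer 16 is reached. See the per-layer transfer logic in forward().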
self.stream_device_map = dict()
......@@ -192,29 +215,47 @@ class KQwen2MoeModel(BaseInjectedModule):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
per_layer_prefill_intput_threshold: int | None = None, # if None or 0, close per-layer prefill
per_layer_prefill_intput_threshold: (
int | None
) = None, # if None or 0, close per-layer prefill
) -> Union[Tuple, MoeModelOutputWithPast]:
# print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}')
if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
if per_layer_prefill_intput_threshold is None:
per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
per_layer_prefill_flag = False
seq_lenth = inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
if per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth:
seq_lenth = (
inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
)
if (
per_layer_prefill_intput_threshold
and per_layer_prefill_intput_threshold < seq_lenth
):
per_layer_prefill_flag = True
for layer in self.layers:
self.load_layer_to(layer, InferenceState.UNLOAD)
else:
pass
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
......@@ -243,15 +284,23 @@ class KQwen2MoeModel(BaseInjectedModule):
inputs_embeds = inputs_embeds.to("cuda")
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds
......@@ -263,7 +312,7 @@ class KQwen2MoeModel(BaseInjectedModule):
next_decoder_cache = None
for i, decoder_layer in enumerate(self.layers):
if self.transfer_map is not None and i in self.transfer_map:
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
......@@ -271,11 +320,25 @@ class KQwen2MoeModel(BaseInjectedModule):
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
hidden_states = hidden_states.to(
self.transfer_map[i], non_blocking=True
)
causal_mask = (
causal_mask.to(self.transfer_map[i], non_blocking=True)
if causal_mask is not None
else None
)
position_ids = (
position_ids.to(self.transfer_map[i], non_blocking=True)
if position_ids is not None
else None
)
cache_position = (
cache_position.to(self.transfer_map[i], non_blocking=True)
if cache_position is not None
else None
)
if output_hidden_states:
all_hidden_states += (hidden_states,)
......@@ -323,7 +386,6 @@ class KQwen2MoeModel(BaseInjectedModule):
hidden_states = self.norm(hidden_states)
if per_layer_prefill_flag:
per_layer_prefill_flag = False
for layer in self.layers:
......@@ -333,12 +395,22 @@ class KQwen2MoeModel(BaseInjectedModule):
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_router_logits,
]
if v is not None
)
return MoeModelOutputWithPast(
......@@ -349,11 +421,13 @@ class KQwen2MoeModel(BaseInjectedModule):
router_logits=all_router_logits,
)
def load_layer_to(self, layer:Qwen2MoeDecoderLayer, target: InferenceState):
assert isinstance(layer, Qwen2MoeDecoderLayer), "layer should be a Qwen2MoeDecoderLayer"
def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState):
assert isinstance(
layer, Qwen2MoeDecoderLayer
), "layer should be a Qwen2MoeDecoderLayer"
# TODO Support restore to original device, not only cuda
device = "cpu" if target == InferenceState.UNLOAD else "cuda"
device = "cpu" if target == InferenceState.UNLOAD else "cuda"
# attn
layer.self_attn.q_proj.set_inference_mode(target)
......@@ -458,18 +532,21 @@ class KDeepseekV2Model(BaseInjectedModule):
Args:
config: DeepseekV2Config
"""
def __init__(
self,
key: str,
gguf_loader : GGUFLoader,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, device, **kwargs
)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
......@@ -487,15 +564,23 @@ class KDeepseekV2Model(BaseInjectedModule):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
per_layer_prefill_intput_threshold: int | None = None, # if None, no per-layer prefill
per_layer_prefill_intput_threshold: (
int | None
) = None, # if None, no per-layer prefill
) -> Union[Tuple, BaseModelOutputWithPast]:
if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
if per_layer_prefill_intput_threshold is None:
per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
per_layer_prefill_flag = False
seq_lenth = inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
if per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth:
seq_lenth = (
inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
)
if (
per_layer_prefill_intput_threshold
and per_layer_prefill_intput_threshold < seq_lenth
):
per_layer_prefill_flag = True
for layer in self.layers:
self.load_layer_to(layer, InferenceState.UNLOAD)
self.load_layer_to(layer, InferenceState.UNLOAD)
torch.cuda.empty_cache()
else:
pass
......@@ -542,9 +627,13 @@ class KDeepseekV2Model(BaseInjectedModule):
past_key_values_length = past_key_values.get_usable_length(seq_length)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
......@@ -556,15 +645,17 @@ class KDeepseekV2Model(BaseInjectedModule):
inputs_embeds = self.embed_tokens(input_ids)
input_ids = input_ids.to(org_device)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
if per_layer_prefill_flag:
causal_mask = None
else:
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
hidden_states = inputs_embeds
if per_layer_prefill_flag:
print(f'Total length of input_ids: {hidden_states.size(1)}')
print(f"Total length of input_ids: {hidden_states.size(1)}")
# decoder layers
all_hidden_states = () if output_hidden_states else None
......@@ -576,7 +667,7 @@ class KDeepseekV2Model(BaseInjectedModule):
t_f = 0
for i, decoder_layer in enumerate(self.layers):
if self.transfer_map is not None and i in self.transfer_map:
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
if cur_device not in self.stream_device_map:
......@@ -584,10 +675,24 @@ class KDeepseekV2Model(BaseInjectedModule):
torch.cuda.set_device(cur_device)
self.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.stream_device_map[cur_device])
hidden_states = hidden_states.to(self.transfer_map[i], non_blocking = True)
causal_mask = causal_mask.to(self.transfer_map[i], non_blocking = True) if causal_mask is not None else None
position_ids = position_ids.to(self.transfer_map[i], non_blocking = True) if position_ids is not None else None
cache_position = cache_position.to(self.transfer_map[i], non_blocking = True) if cache_position is not None else None
hidden_states = hidden_states.to(
self.transfer_map[i], non_blocking=True
)
causal_mask = (
causal_mask.to(self.transfer_map[i], non_blocking=True)
if causal_mask is not None
else None
)
position_ids = (
position_ids.to(self.transfer_map[i], non_blocking=True)
if position_ids is not None
else None
)
cache_position = (
cache_position.to(self.transfer_map[i], non_blocking=True)
if cache_position is not None
else None
)
if output_hidden_states:
all_hidden_states += (hidden_states,)
......@@ -622,12 +727,12 @@ class KDeepseekV2Model(BaseInjectedModule):
t5 = time.time()
if per_layer_prefill_flag:
# print(f"to cpu")
self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
torch.cuda.empty_cache()
t6 = time.time()
t_gpu += t4-t3
t_cpu += t6-t5
t_f += t5-t4
t_gpu += t4 - t3
t_cpu += t6 - t5
t_f += t5 - t4
hidden_states = layer_outputs[0]
......@@ -648,7 +753,9 @@ class KDeepseekV2Model(BaseInjectedModule):
torch.cuda.empty_cache()
t7 = time.time()
print(f"total time: {t7-t3}, \n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}")
print(
f"total time: {t7-t3}, \n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}"
)
# add hidden states from the last decoder layer
if output_hidden_states:
......@@ -674,16 +781,18 @@ class KDeepseekV2Model(BaseInjectedModule):
attentions=all_self_attns,
)
def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
assert isinstance(layer, DeepseekV2DecoderLayer), "layer should be a DeepseekV2DecoderLayer"
def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
assert isinstance(
layer, DeepseekV2DecoderLayer
), "layer should be a DeepseekV2DecoderLayer"
# TODO Support restore to original device, not only cuda
device = "cpu" if target == InferenceState.UNLOAD else "cuda"
device = "cpu" if target == InferenceState.UNLOAD else "cuda"
# TODO Support DFS to auto use {to, set_inference_mode} according to the module type
# attn
layer.self_attn.to(device) #
layer.self_attn.to(device) #
# mlp
if isinstance(layer.mlp, DeepseekV2MoE):
......@@ -702,3 +811,526 @@ class KDeepseekV2Model(BaseInjectedModule):
# layer norm
layer.input_layernorm.to(device)
layer.post_attention_layernorm.to(device)
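# A minimal sketch (not part of ktransformers) of how the per-layer prefill path above
# drives load_layer_to: each decoder layer is moved to the GPU only for its own forward
# pass and then evicted again, so VRAM usage stays close to a single layer's footprint.
# `InferenceState.LOAD` and `layer_kwargs` are illustrative assumptions; the real enum
# member and per-layer arguments may differ.
def _per_layer_prefill_sketch(model, hidden_states, **layer_kwargs):
    for layer in model.layers:
        model.load_layer_to(layer, InferenceState.LOAD)    # assumed counterpart of UNLOAD
        hidden_states = layer(hidden_states, **layer_kwargs)[0]
        model.load_layer_to(layer, InferenceState.UNLOAD)  # weights back to CPU
        torch.cuda.empty_cache()                           # release the freed blocks
    return hidden_states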
LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class KLlamaModel(BaseInjectedModule):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
dynamic_sdpa = None
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill
transfer_map: dict = None,
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, device, **kwargs
)
self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
self.transfer_map = transfer_map
self.stream_device_map = dict()
user_path: str = os.path.expanduser('~')
localstore_path: str = os.path.join(user_path,'.ktransformers')
config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)
with open(config_path,"r") as file:
config_yaml = yaml.safe_load(file.read())
self.long_context_config = config_yaml.get("long_context")
self.ext_config = config_yaml.get("ext")
KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention(
max_seq_len=self.long_context_config["max_seq_len"],
block_size=self.long_context_config["block_size"],
config=config,
device=torch.device("cuda"),
local_windows_len=self.long_context_config["local_windows_len"],
topk=self.long_context_config["second_select_num"],
threads_num=self.ext_config["cpu_infer"],
anchor_type=self.long_context_config["anchor_type"],
kv_type=self.long_context_config["kv_type"],
dense_layer_num=self.long_context_config["dense_layer_num"],
anchor_num=self.long_context_config["anchor_num"],
preselect_block=self.long_context_config["preselect_block"],
block_selection_mode=self.long_context_config["head_select_mode"],
preselect_block_count=self.long_context_config["preselect_block_count"],
layer_step=self.long_context_config["layer_step"],
token_step=self.long_context_config["token_step"],
prefill_chunk_size=self.long_context_config["chunk_size"],
use_attn_sparsity=False,
)
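# The DynamicScaledDotProductAttention above is configured entirely from the user config
# at ~/.ktransformers/<Config.CONFIG_FILE_NAME>. Its `long_context` block must supply the
# keys read here: max_seq_len, block_size, local_windows_len, second_select_num,
# anchor_type, kv_type, dense_layer_num, anchor_num, preselect_block, head_select_mode,
# preselect_block_count, layer_step, token_step and chunk_size; `ext.cpu_infer` sets the
# number of CPU threads. All of these are indexed directly, so a missing key fails at
# construction time rather than at inference time.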
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device="cuda",
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = None
chunck_size = self.long_context_config["chunk_size"]
cur_idx = 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids.to("cpu"))
q_len = cache_position.size(0)
# generate
if q_len == 1:
x = inputs_embeds[:, -1:, :]
position_ids = position_ids[:, -1:]
return self.forward_chunk(
x,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
output_hidden_states,
return_dict,
)
elif q_len <= chunck_size:
inputs_embeds = inputs_embeds.to('cuda')
output = self.forward_chunk(
inputs_embeds,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
output_hidden_states,
return_dict,
)
KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
return output
cur_idx = 0
assert not output_attentions, "output_attentions is not supported when using chunked attention"
attn_output = None
# prefill
KLlamaModel.dynamic_sdpa.remaining_length = q_len
while cur_idx < q_len:
print(f'current prefill length: {cur_idx}')
chunk_mask = None
if inputs_embeds.device.type == 'cpu':
tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to("cuda")
else:
tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)]
output_with_past = self.forward_chunk(
tmp_inputs_embeds,
chunk_mask,
position_ids[:, cur_idx : min(cur_idx + chunck_size, q_len)],
past_key_values,
output_attentions,
use_cache,
cache_position[cur_idx : min(cur_idx + chunck_size, q_len)],
)
cur_output = output_with_past.last_hidden_state
KLlamaModel.dynamic_sdpa.remaining_length -= (
min(cur_idx + chunck_size, q_len) - cur_idx
)
cur_idx += chunck_size
# if attn_output is None:
attn_output = cur_output
# else:
# attn_output = torch.cat((attn_output, cur_output), dim=-2)
KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
return BaseModelOutputWithPast(last_hidden_state=attn_output)
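# The prefill loop above walks the prompt in fixed windows of
# self.long_context_config["chunk_size"]. A minimal sketch of the boundary arithmetic
# (numbers are illustrative only):
#
#     q_len, chunck_size = 10_000, 4096
#     bounds = [(i, min(i + chunck_size, q_len)) for i in range(0, q_len, chunck_size)]
#     # -> [(0, 4096), (4096, 8192), (8192, 10000)]
#
# Each pair slices inputs_embeds, position_ids and cache_position for one forward_chunk
# call; only the last chunk's hidden states are returned, which is enough for generation
# but not for per-token logits over the whole prompt.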
def forward_chunk(
self,
inputs_embeds,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_legacy_cache = False
if use_cache and not isinstance(
past_key_values, Cache
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
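# Note: the rotary position embeddings are computed once per chunk via self.rotary_emb
# and the resulting (cos, sin) pair is shared by every decoder layer; the cache is only
# converted back to the legacy tuple format when the caller originally passed one.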
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not using_static_cache
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError(
"Custom 4D attention mask should be passed in inverted form with max==0`"
)
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(
input_tensor.shape[0], 1, -1, -1
)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length]
+ attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
return causal_mask
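# A minimal standalone sketch of the dense-mask branch above (shapes and values are
# illustrative only):
#
#     dtype = torch.float32
#     seq_len, tgt_len = 4, 6                        # current chunk vs. total cache slots
#     cache_position = torch.arange(2, 2 + seq_len)  # positions of the chunk's tokens
#     mask = torch.full((seq_len, tgt_len), torch.finfo(dtype).min, dtype=dtype)
#     mask = torch.triu(mask, diagonal=1)
#     mask *= torch.arange(tgt_len) > cache_position.reshape(-1, 1)
#
# After this, row i can attend to every cache slot up to and including its own cache
# position, and everything later stays at the dtype minimum (i.e. masked out).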
......@@ -225,4 +225,4 @@
class: "default"
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
\ No newline at end of file
prefill_device: "cuda:3"
......@@ -123,4 +123,4 @@
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
\ No newline at end of file
prefill_device: "cuda:1"
......@@ -6,7 +6,7 @@
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn).*$" # regular expression
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
......@@ -41,6 +41,12 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 2000 # 0 disables layer-wise prefill
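# per_layer_prefill_intput_threshold is measured in prompt tokens: prompts longer than
# this are prefilled layer by layer with weights streamed between CPU and GPU (lower
# VRAM, slower prefill); 0 or None keeps every layer resident.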
- match:
name: "^model.embed_tokens"
replace:
......
......@@ -123,4 +123,4 @@
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
\ No newline at end of file
prefill_device: "cuda:1"
- match:
class: ktransformers.models.modeling_llama.LlamaRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV2
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: ktransformers.models.modeling_llama.LlamaModel
replace:
class: ktransformers.operators.models.KLlamaModel
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KLlamaAttention
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
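# Summary of the rules above: rotary embeddings are replaced with RotaryEmbeddingV2,
# the token embedding stays on CPU, LlamaModel is swapped for KLlamaModel (with
# layer-wise prefill disabled), and every self_attn module becomes KLlamaAttention.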
......@@ -109,4 +109,4 @@
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
\ No newline at end of file
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\..*\\."
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
......@@ -54,4 +61,4 @@
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
\ No newline at end of file
prefill_device: "cuda"
......@@ -5,10 +5,11 @@ Description :
Author : unicornchan
Date : 2024-06-11 16:35:42
Version : 1.0.0
LastEditors : chenxl
LastEditTime : 2024-07-27 01:55:42
LastEditors : WuHao
LastEditTime : 2024-08-12 06:31:14
'''
import os
import shutil
import yaml
from ktransformers.server.config.singleton import Singleton
......@@ -30,10 +31,18 @@ class Config(metaclass=Singleton):
os.path.dirname(os.path.dirname(__file__)))
config_yaml: str = os.path.join(
base_path, "configs", Config.CONFIG_FILE_NAME)
user_path: str = os.path.expanduser('~')
localstore_path: str = os.path.join(user_path,'.ktransformers')
config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)
if not os.path.exists(config_yaml):
print(f"Can't find config file, {config_yaml}")
exit(-1)
with open(config_yaml, 'r', encoding="utf-8") as fp:
if not os.path.exists(localstore_path):
os.mkdir(localstore_path)
if not os.path.exists(config_path):
shutil.copyfile(config_yaml,config_path)
with open(config_path, 'r', encoding="utf-8") as fp:
config = yaml.safe_load(fp)
return config
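# On first use the packaged configs/<CONFIG_FILE_NAME> is copied to
# ~/.ktransformers/<CONFIG_FILE_NAME>, and every later call to Config.load() reads that
# user copy, so local edits survive package upgrades; deleting the user copy restores
# the packaged defaults on the next run.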
......@@ -51,6 +60,8 @@ class Config(metaclass=Singleton):
cfg = Config.load()
self.base_path = os.path.dirname(
os.path.dirname(os.path.dirname(__file__)))
self.user_path: str = os.path.expanduser('~')
self.localstore_path: str = os.path.join(self.user_path,'.ktransformers')
# log configs
self.log_dir = os.path.join(self.base_path, Config.to_path(cfg["log"]["dir"]))
self.log_file = cfg["log"]["file"]
......@@ -83,11 +94,20 @@ class Config(metaclass=Singleton):
self.model_name: str = self.model.get("name", "")
self.model_device: str = self.model.get("device", "cuda:0")
self.gguf_path: str = self.model.get("gguf_path", "")
self.model_cache_lens = self.model.get("cache_lens")
# web config
self.web: dict = cfg.get("web", {})
self.web_cross_domain: bool = self.web.get("open_cross_domain", True)
self.mount_web: bool = self.web.get("mount", False)
self.ext: dict = cfg.get("ext", {})
self.cpu_infer = self.ext.get("cpu_infer", 10)
# file config
self.local_store_configs: dict = cfg.get("local_store", {})
self.file_upload_dir: str = os.path.join(self.localstore_path, self.local_store_configs.get("file_upload_dir", ""))
self.assistant_store_dir: str = os.path.join(self.localstore_path, self.local_store_configs.get("assistant_store_dir", ""))
# long context config
self.long_context_config: dict = cfg.get("long_context", {})
\ No newline at end of file
......@@ -46,7 +46,8 @@ class CUDAGraphRunner:
capture_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.set_device(main_device)
torch.cuda.set_stream(capture_stream)
past_key_values.change_seq_length(-1)
if past_key_values is not None:
past_key_values.change_seq_length(-1)
torch.cuda.synchronize(self.main_device)
#self.graph.debug_dump("cuda_graph_hooked.dot")
......
......@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-26 08:48:54
Version : 1.0.0
LastEditors : kkk1nak0
LastEditTime : 2024-08-12 07:21:55
LastEditTime : 2024-08-14 08:20:45
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
Copyright (c) 2023-2024 The ggml authors
Copyright (c) 2024 Thomas Germer
......@@ -294,7 +294,6 @@ class GGUFLoader:
else:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values)
values = values.view(shape[::-1])
if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
n_head = self.gguf_file_meta['llama.attention.head_count']
......
......@@ -84,7 +84,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
else:
module.load()
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal'):
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._dynamo.config.suppress_errors = True
......@@ -110,7 +111,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
past_key_values.change_seq_length(1)
if past_key_values is not None:
past_key_values.change_seq_length(1)
for device in all_cuda_device:
torch.cuda.synchronize(device)
#print(logits)
......@@ -125,18 +127,26 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
torch.cuda.set_device(torch_device)
with torch.no_grad():
stream = TextStreamer(tokenizer)
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
)
if mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
)
else:
past_key_values = None
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
past_key_values.cur_idx=cache_position
if past_key_values is not None:
past_key_values.cur_idx = cache_position
start_time = time.time()
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
if mode == "long_context":
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
else:
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
logits = model(
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
......@@ -184,7 +194,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
tokens.append(next_token.int())
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id:
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
......
......@@ -27,7 +27,8 @@ dependencies = [
"wheel",
"colorlog",
"build",
"fire"
"fire",
"protobuf"
]
requires-python = ">=3.10"
......
......@@ -3,4 +3,5 @@ transformers
numpy
torch>=2.3.0
packaging
cpufeature
\ No newline at end of file
cpufeature
protobuf
\ No newline at end of file