Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
038bc308
Commit
038bc308
authored
Feb 17, 2025
by
Atream
Browse files
fix precision bug introduced by position_ids in 0.2.0
parent
b8452462
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
471 additions
and
45 deletions
+471
-45
ktransformers/local_chat.py
ktransformers/local_chat.py
+11
-3
ktransformers/models/custom_cache.py
ktransformers/models/custom_cache.py
+0
-2
ktransformers/operators/RoPE.py
ktransformers/operators/RoPE.py
+9
-8
ktransformers/operators/attention.py
ktransformers/operators/attention.py
+189
-22
ktransformers/operators/base_operator.py
ktransformers/operators/base_operator.py
+5
-2
ktransformers/operators/experts.py
ktransformers/operators/experts.py
+5
-1
ktransformers/operators/flashinfer_wrapper.py
ktransformers/operators/flashinfer_wrapper.py
+240
-0
ktransformers/operators/gate.py
ktransformers/operators/gate.py
+2
-2
ktransformers/operators/linear.py
ktransformers/operators/linear.py
+1
-1
ktransformers/util/utils.py
ktransformers/util/utils.py
+9
-4
No files found.
ktransformers/local_chat.py
View file @
038bc308
...
...
@@ -30,6 +30,7 @@ from ktransformers.models.modeling_llama import LlamaForCausalLM
from
ktransformers.models.modeling_mixtral
import
MixtralForCausalLM
from
ktransformers.util.utils
import
prefill_and_generate
from
ktransformers.server.config.config
import
Config
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
custom_models
=
{
"DeepseekV2ForCausalLM"
:
DeepseekV2ForCausalLM
,
...
...
@@ -170,8 +171,15 @@ def local_chat(
torch
.
set_default_dtype
(
torch
.
bfloat16
)
# TODO: Remove this, replace dtype using config
if
system
!=
"Windows"
and
(
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
"DeepseekV3ForCausalLM"
)
and
flashinfer_enabled
:
generated
=
prefill_and_generate
(
model
,
tokenizer
,
input_tensor
.
cuda
(),
max_new_tokens
,
use_cuda_graph
,
mode
=
mode
,
force_think
=
force_think
,
use_flashinfer_mla
=
True
,
num_heads
=
config
.
num_attention_heads
,
head_dim_ckv
=
config
.
kv_lora_rank
,
head_dim_kpe
=
config
.
qk_rope_head_dim
,
q_head_dim
=
config
.
qk_rope_head_dim
+
config
.
qk_nope_head_dim
)
else
:
generated
=
prefill_and_generate
(
model
,
tokenizer
,
input_tensor
.
cuda
(),
max_new_tokens
,
use_cuda_graph
,
mode
,
force_think
model
,
tokenizer
,
input_tensor
.
cuda
(),
max_new_tokens
,
use_cuda_graph
,
mode
=
mode
,
force_think
=
force_think
,
)
...
...
ktransformers/models/custom_cache.py
View file @
038bc308
...
...
@@ -138,8 +138,6 @@ class StaticCache(transformers.StaticCache):
page_idx
=
cache_position
//
self
.
page_size
page_offset
=
cache_position
%
self
.
page_size
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
#print("page_idx", page_idx)
#print("page_offset", page_offset)
k_out
[
page_idx
,
page_offset
,
:,
:
self
.
kv_lora_rank
]
=
key_states
k_out
[
page_idx
,
page_offset
,
:,
self
.
kv_lora_rank
:]
=
value_states
return
k_out
,
self
.
page_table_list
[
layer_idx
]
...
...
ktransformers/operators/RoPE.py
View file @
038bc308
...
...
@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
dim
,
orig_module
.
max_position_embeddings
,
orig_module
.
base
...
...
@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
self
.
generate_device
=
generate_device
self
.
prefill_device
=
prefill_device
...
...
@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
dim
,
...
...
@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
dim
,
...
...
@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
# **kwargs,
# ):
# BaseInjectedModule.__init__(
# self, key, gguf_loader, config, orig_module, generate_device, **kwargs
# self, key, gguf_loader, config, orig_module,
prefill_device,
generate_device, **kwargs
# )
# self.generate_device = generate_device
# self.prefill_device = prefill_device
...
...
@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
self
.
generate_device
=
generate_device
self
.
prefill_device
=
prefill_device
...
...
@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
device
:
str
=
"cuda"
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
device
,
**
kwargs
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_
device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
dim
,
...
...
ktransformers/operators/attention.py
View file @
038bc308
...
...
@@ -19,9 +19,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
import
logging
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.cache_utils
import
Cache
from
flash_attn
import
flash_attn_with_kvcache
,
flash_attn_func
from
flash_attn
import
flash_attn_func
from
ktransformers.operators.triton_attention
import
decode_attention_fwd_grouped
import
os
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
if
flashinfer_enabled
:
from
ktransformers.operators.flashinfer_wrapper
import
MLAWrapperSingleton
,
attention_ref
logger
=
logging
.
getLogger
(
"attention"
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
...
...
@@ -41,15 +45,15 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
device
:
str
=
"cuda"
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
chunck_size
:
int
=
1000
,
use_triton
:
bool
=
False
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
device
,
**
kwargs
)
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_
device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
self
.
use_triton
=
use_trit
on
self
.
mla_wrapper
=
N
on
e
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
...
...
@@ -141,6 +145,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
#print(compressed_kv.shape)
attn_weights
=
(
torch
.
matmul
(
q_pe
,
k_pe
.
mT
)
+
torch
.
matmul
(
q_nope
,
compressed_kv
.
mT
))
*
self
.
softmax_scale
#attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
compressed_kv
=
compressed_kv
.
squeeze
(
1
)
"""
...
...
@@ -168,6 +173,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
attn_weights
=
nn
.
functional
.
dropout
(
attn_weights
,
p
=
self
.
attention_dropout
,
training
=
self
.
training
)
attn_output
=
torch
.
einsum
(
'bhql,blc->bhqc'
,
attn_weights
,
compressed_kv
)
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
...
...
@@ -186,7 +192,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
return
attn_output
,
None
,
past_key_value
def
forward_linux
(
def
forward_linux
_triton
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -267,7 +273,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# use triton attention kernel adapted from vLLM and SGLang for MQA
decode_attention_fwd_grouped
(
query_states
,
compressed_kv_with_k_pe
,
compressed_kv
,
attn_output
,
page_table
,
position_ids
.
squeeze
(
0
).
to
(
torch
.
int32
),
attn_logits
,
position_ids
.
squeeze
(
0
).
to
(
torch
.
int32
)
+
1
,
attn_logits
,
4
,
#num_kv_splits # follow vLLM, fix it TODO
self
.
softmax_scale
,
past_key_value
.
page_size
)
...
...
@@ -326,6 +332,154 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
attn_output
=
self
.
o_proj
(
attn_output
)
return
attn_output
,
None
,
past_key_value
def forward_linux_flashinfer(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """MLA attention forward pass: flashinfer for decode, flash-attn for prefill.

    The query / compressed-KV projections and rotary embedding are shared by both
    paths. After that the method branches on ``q_len``:

    * ``q_len == 1`` (decode): the new token is written into the cache, the
      absorbed query is built via ``get_absorbed()`` and attention is computed by
      the per-device ``MLAWrapper`` over the pages returned by
      ``past_key_value.update``.
    * otherwise (prefill): keys/values are expanded with ``kv_b_proj`` and
      attention is computed by ``flash_attn_func`` with ``causal=True``.

    Args:
        hidden_states: input activations; unpacked as ``(bsz, q_len, _)``.
        attention_mask: accepted for interface compatibility; not read here.
        position_ids: positions passed to ``rotary_emb``; in the decode path
            ``position_ids.squeeze(1) + 1`` is used as the KV length for planning.
        past_key_value: cache object; its ``update``, ``page_size`` and
            ``max_pages`` members are used.
        output_attentions: not read here.
        use_cache: not read here.
        cache_position: forwarded to ``past_key_value.update`` via ``cache_kwargs``.
        **kwargs: not read here.

    Returns:
        ``(attn_output, None, past_key_value)`` where ``attn_output`` is the
        result of ``o_proj``.

    NOTE(review): block nesting below was reconstructed from a flattened view of
    the source; confirm the extent of the two ``if past_key_value is not None``
    bodies against the real file.
    """
    bsz, q_len, _ = hidden_states.size()

    # Query projection: direct, or low-rank (a_proj -> layernorm -> b_proj).
    if self.q_lora_rank is None:
        q = self.q_proj(hidden_states)
    else:
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
    q_nope, q_pe = torch.split(
        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
    )

    # One shared projection yields the compressed KV latent and the rotary key part.
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
    )
    compressed_kv = self.kv_a_layernorm(compressed_kv)
    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
    compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

    cos, sin = self.rotary_emb(q_pe, position_ids)
    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
    # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

    # decode
    if q_len == 1:
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            # update() returns the whole paged buffer (latent and rope part concatenated
            # on the last dim) plus a page table; page_table is not used below.
            compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)
            # k_pe [max_pages, page_size, self.qk_rope_head_dim]
            # compressed_kv [max_pages, page_size, self.kv_lora_rank]

        # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
        # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
        q_absorb, out_absorb = self.get_absorbed()
        q_nope = q_nope.transpose(1, 2)  # q_len is 1, no GPU overhead, same below
        q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
        q_nope = q_nope.transpose(1, 2)
        assert q_nope.is_contiguous()

        # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
        # In-place squeeze: drops dim 1 when it has size 1 (q_len is 1 here).
        q_nope.squeeze_(1)
        q_pe.squeeze_(1)

        # flash attn doesn't support head_dim bigger than 256, use flashinfer
        # NOTE(review): past_key_value.max_pages / .page_size are read
        # unconditionally below, so this path fails if past_key_value is None.
        if self.mla_wrapper is None:
            self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
        # Plan only while the wrapper's need_plan flag is set; afterwards the plan
        # is presumably refreshed by an external caller — TODO confirm.
        if self.mla_wrapper.need_plan:
            self.mla_wrapper.need_plan = False
            self.mla_wrapper.plan(None, None, None,
                                  position_ids.squeeze(1)+1,
                                  self.num_heads,
                                  self.kv_lora_rank,
                                  self.qk_rope_head_dim,
                                  past_key_value.page_size,
                                  self.softmax_scale,
                                  q_nope.dtype,
                                  compressed_kv.dtype)
        attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)

        # The string below is disabled reference code (checks against attention_ref);
        # it is a bare string expression and is never executed.
        """
k = (
torch.cat([compressed_kv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(self.num_heads, dim=1)
)
v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)
lens = position_ids.item() + 1
#print("lens", lens)
attn_ref, lse_ref = attention_ref(
1,
torch.cat([q_nope, q_pe], dim=-1),
k[:lens],
v[:lens],
False,
self.softmax_scale
)
attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
"""

        # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
        # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
        # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
        attn_output = attn_output.transpose(1, 2)
        attn_output = torch.matmul(attn_output, out_absorb.mT)

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value
    else:
        # prefill: attention over the q_len tokens of this call only.
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            # NOTE(review): squeeze()/unsqueeze() are the out-of-place variants and
            # their results are discarded, so these four calls leave k_pe and
            # compressed_kv unchanged; only update() has an effect here.
            k_pe.squeeze(0)
            compressed_kv.squeeze(0)
            past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
            k_pe.unsqueeze(0)
            compressed_kv.unsqueeze(0)

        # dim 1 already has length q_len (see the .view calls above).
        k_pe = k_pe[:, :q_len]
        compressed_kv = compressed_kv[:, :q_len]
        # Expand the latent into per-head no-rope keys and values.
        kv = (
            self.kv_b_proj(compressed_kv)
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        )
        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Assemble full query/key heads as [no-rope part | rope part].
        query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        # k_pe has a single head (dim 2 == 1); the assignment broadcasts it to all heads.
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

        value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
        # Zero-pad values up to the query head dim; presumably flash_attn_func
        # needs q/k/v to share a head dim — TODO confirm. Padding is cut off below.
        value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states_padded,
            softmax_scale=self.softmax_scale,
            causal=True,
        )

        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value
def
forward_windows
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -403,7 +557,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position
:
Optional
[
torch
.
LongTensor
]
=
None
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
if
not
self
.
use_triton
:
#
os.name == 'nt'
if
os
.
name
==
'nt'
:
return
self
.
forward_windows
(
hidden_states
,
attention_mask
,
...
...
@@ -415,7 +569,19 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
**
kwargs
,
)
else
:
return
self
.
forward_linux
(
if
flashinfer_enabled
:
return
self
.
forward_linux_flashinfer
(
hidden_states
,
attention_mask
,
position_ids
,
past_key_value
,
output_attentions
,
use_cache
,
cache_position
,
**
kwargs
,
)
else
:
return
self
.
forward_linux_triton
(
hidden_states
,
attention_mask
,
position_ids
,
...
...
@@ -435,9 +601,10 @@ class KLlamaAttention(BaseInjectedModule):
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
device
:
str
=
"cuda"
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
device
,
**
kwargs
)
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_
device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
def
apply_rotary_pos_emb
(
self
,
q
,
k
,
cos
,
sin
,
position_ids
=
None
,
unsqueeze_dim
=
1
):
...
...
ktransformers/operators/base_operator.py
View file @
038bc308
...
...
@@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
device
:
str
=
"cuda"
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
**
kwargs
):
nn
.
Module
.
__init__
(
self
)
nn
.
Module
.
__setattr__
(
self
,
"orig_module"
,
orig_module
)
object
.
__setattr__
(
self
,
"key"
,
key
)
object
.
__setattr__
(
self
,
"gguf_loader"
,
gguf_loader
)
object
.
__setattr__
(
self
,
"config"
,
config
)
object
.
__setattr__
(
self
,
"device"
,
device
)
object
.
__setattr__
(
self
,
"prefill_device"
,
prefill_device
)
object
.
__setattr__
(
self
,
"generate_device"
,
generate_device
)
object
.
__setattr__
(
self
,
"device"
,
generate_device
)
def
__getattr__
(
self
,
name
:
str
)
->
Any
:
# __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
...
...
ktransformers/operators/experts.py
View file @
038bc308
...
...
@@ -119,6 +119,7 @@ class KExpertsCPU(KExpertsBase):
output_cpu
:
Tensor
=
None
output_gpu_map
:
dict
=
{}
# Manage output tensor buffer on different gpu
#stream_map:dict = {} # Manage cuda stream on different gpu
#gguf_loader:GGUFLoader = None
CPU_INFER
=
CPUInfer
(
Config
().
cpu_infer
)
def
__init__
(
self
,
...
...
@@ -132,6 +133,9 @@ class KExpertsCPU(KExpertsBase):
**
kwargs
):
super
().
__init__
(
key
,
gguf_loader
,
config
,
orig_module
,
device
,
**
kwargs
)
#if KExpertsCPU.gguf_loader is None:
# KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
self
.
gguf_loader
=
gguf_loader
assert
device
.
lower
()
==
"cpu"
,
"KExpertsCPU can only be loaded on CPU"
self
.
n_routed_experts
=
n_routed_experts
self
.
out_device
=
out_device
...
...
@@ -532,7 +536,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
generate_device
:
str
=
"cpu"
,
generate_op
:
str
|
None
=
"KExpertsCPU"
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
KExpertsBase
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
if
generate_op
is
not
None
:
self
.
generate_experts
=
EXPERTS_MAP
[
generate_op
](
key
,
gguf_loader
,
config
,
len
(
orig_module
),
device
=
generate_device
,
**
kwargs
)
...
...
ktransformers/operators/flashinfer_wrapper.py
0 → 100644
View file @
038bc308
'''
Description : flashinfer MLA wrapper
Author : Boxin Zhang
Version : 0.2.2
'''
import
torch
# Optional dependency probe: other modules import `flashinfer_enabled` to decide
# whether the flashinfer MLA path may be used.
try:
    import flashinfer
except ImportError:
    flashinfer_enabled = False
    print("flashinfer not found, use triton for linux")
else:
    flashinfer_enabled = True
    print("found flashinfer")
import
math
def attention_ref(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Plain PyTorch attention used as a reference for the fused MLA kernel.

    Inputs are packed along dim 0 (``batch_size`` sequences of equal length).
    All math runs in float32; the output is cast back to ``q``'s dtype/device.

    Args:
        batch_size: number of sequences packed in ``q``/``k``/``v``.
        q: queries, ``[batch_size * qo_len, num_heads, head_dim_qk]``.
        k: keys, ``[batch_size * kv_len, num_heads, head_dim_qk]``.
        v: values, ``[batch_size * kv_len, num_heads, head_dim_vo]``.
        causal: if True, query ``i`` (aligned to the END of the key sequence)
            only sees keys at positions ``<= kv_len - qo_len + i``.
        sm_scale: factor applied to the raw ``q·k`` logits before softmax.

    Returns:
        ``(o_ref, lse)``: ``o_ref`` is ``[batch_size * qo_len, num_heads,
        head_dim_vo]``; ``lse`` is the log-sum-exp of the scaled logits,
        shaped ``[batch_size, qo_len, num_heads]`` and expressed in base 2
        (natural-log value multiplied by ``log2(e)``).
    """
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]

    logits = (
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
        )
        * sm_scale
    )
    #print("attn weights", logits)

    if causal:
        # Build the mask directly on q's device instead of on the CPU and
        # copying it over afterwards.
        query_pos = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze(1)
        key_pos = torch.arange(0, kv_len, device=q.device).unsqueeze(0)
        visible = query_pos >= key_pos  # [qo_len, kv_len]
        logits = logits.masked_fill(~visible.unsqueeze(0).unsqueeze(0), float("-inf"))
    # Non-causal: every key is visible, so no mask is needed.

    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
    p = torch.softmax(logits, dim=-1)
    o_ref = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            p,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )

    # Convert the log-sum-exp from natural log to base 2.
    return o_ref, lse_ref * math.log2(math.e)
class MLAWrapper():
    """Thin holder around ``flashinfer.mla.BatchMLAPagedAttentionWrapper``.

    Owns the workspace buffer and, when ``use_cuda_graph`` is set, the
    pre-allocated index buffers that ``plan`` falls back to when the caller
    passes ``None``. ``need_plan`` starts as True; callers clear it themselves.
    """

    def __init__(self,
                 max_batch_size,
                 max_pages,
                 use_cuda_graph=True,
                 device="cuda",
                 ):
        # 128 MiB of int8 scratch space handed to flashinfer.
        workspace_bytes = 128 * 1024 * 1024
        self.float_workspace_buffer = torch.empty(workspace_bytes, dtype=torch.int8, device=device)
        self.max_batch_size = max_batch_size
        self.max_pages = max_pages

        int32_on_device = {"dtype": torch.int32, "device": device}
        if not use_cuda_graph:
            self.qo_indptr_buf = None
            self.kv_indptr_buf = None
            self.kv_indices_buf = None
            self.kv_len_arr_buf = None
        else:
            if self.max_batch_size == 1:
                # Single sequence: one query token, all pages listed in order.
                self.qo_indptr_buf = torch.arange(0, max_batch_size + 1, **int32_on_device)
                self.kv_indptr_buf = torch.tensor([0, max_pages], **int32_on_device)
                self.kv_indices_buf = torch.arange(0, max_pages, **int32_on_device)
            else:
                # Larger batches: uninitialised buffers, contents supplied later.
                self.qo_indptr_buf = torch.empty(max_batch_size + 1, **int32_on_device)
                self.kv_indptr_buf = torch.empty(max_batch_size + 1, **int32_on_device)
                self.kv_indices_buf = torch.empty(max_pages, **int32_on_device)
            self.kv_len_arr_buf = torch.empty(max_batch_size, **int32_on_device)

        # NOTE(review): the flashinfer wrapper is always created with
        # use_cuda_graph=False, independent of this constructor's argument.
        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
            self.float_workspace_buffer,
            use_cuda_graph=False,
            qo_indptr=self.qo_indptr_buf,
            kv_indptr=self.kv_indptr_buf,
            kv_indices=self.kv_indices_buf,
            kv_len_arr=self.kv_len_arr_buf,
        )
        self.need_plan = True

    def plan(self,
             qo_indptr,
             kv_indptr,
             kv_indices,
             kv_len_arr,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
             page_size,
             sm_scale,
             q_data_type,
             kv_data_type,
             ):
        """Forward planning to flashinfer, substituting own buffers for ``None``.

        ``qo_indptr``, ``kv_indptr`` and ``kv_indices`` may each be ``None``; the
        pre-allocated buffer is then used, which is only allowed (asserted) when
        ``max_batch_size == 1``. ``kv_len_arr`` is passed through unchanged.
        """
        def _or_preallocated(given, preallocated):
            if given is not None:
                return given
            assert self.max_batch_size == 1
            return preallocated

        qo_indptr = _or_preallocated(qo_indptr, self.qo_indptr_buf)
        kv_indptr = _or_preallocated(kv_indptr, self.kv_indptr_buf)
        kv_indices = _or_preallocated(kv_indices, self.kv_indices_buf)

        self.wrapper.plan(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_len_arr,
            num_heads,
            head_dim_ckv,
            head_dim_kpe,
            page_size,
            False,  # causal is False for decoding
            sm_scale,
            q_data_type,
            kv_data_type,
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse=False):
        """Run the planned attention; returns whatever the flashinfer wrapper returns."""
        result = self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
        return result
class MLAWrapperSingleton():
    """Class-level registry keeping at most one ``MLAWrapper`` per device key."""

    # device key -> MLAWrapper; stored on the class, so shared by all callers.
    wrappers:dict = {}

    @classmethod
    def get_instance(cls, device, *args, **kwargs)->MLAWrapper:
        """Return the wrapper registered for ``device``, creating it on first use.

        ``*args``/``**kwargs`` are only consulted when the wrapper has to be
        created; they are ignored if one already exists for ``device``.
        """
        registry = cls.wrappers
        if device in registry:
            return registry[device]
        cls.make_instance(device, *args, **kwargs)
        return cls.wrappers[device]

    @classmethod
    def make_instance(cls, device, *args, **kwargs):
        """Build a new ``MLAWrapper`` on ``device`` and (re)register it."""
        created = MLAWrapper(*args, **kwargs, device=device)
        cls.wrappers[device] = created

    @classmethod
    def plan_all(cls, qo_indptr,
                 kv_indptr,
                 kv_indices,
                 kv_len_arr,
                 num_heads,
                 head_dim_ckv,
                 head_dim_kpe,
                 page_size,
                 sm_scale,
                 q_data_type,
                 kv_data_type,):
        """Call ``plan`` on every registered wrapper.

        ``kv_len_arr`` is moved to each wrapper's device key first; every other
        argument is forwarded unchanged.
        """
        shared_tail = (
            num_heads,
            head_dim_ckv,
            head_dim_kpe,
            page_size,
            sm_scale,
            q_data_type,
            kv_data_type,
        )
        for device in cls.wrappers:
            kv_len_on_device = kv_len_arr.to(device)
            cls.wrappers[device].plan(
                qo_indptr,
                kv_indptr,
                kv_indices,
                kv_len_on_device,
                *shared_tail,
            )
if __name__ == "__main__":
    # Self-test: compare the flashinfer MLA kernel against attention_ref on
    # random bfloat16 data for a single sequence held in one page.
    max_batch_size = 1
    max_pages = 1
    page_size = 64
    num_heads = 128

    head_dim_ckv = 512
    head_dim_kpe = 64
    kv_len = 10
    sm_scale = 192 ** (-0.5)
    bf16_cuda = {"dtype": torch.bfloat16, "device": "cuda"}

    q_nope = torch.randn((1, num_heads, head_dim_ckv), **bf16_cuda)
    q_pe = torch.randn((1, num_heads, head_dim_kpe), **bf16_cuda)
    ckv = torch.randn((max_pages, page_size, head_dim_ckv), **bf16_cuda)
    k_pe = torch.randn((max_pages, page_size, head_dim_kpe), **bf16_cuda)

    wrapper = MLAWrapperSingleton.get_instance(
        "cuda",
        max_batch_size,
        max_pages,
    )

    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")

    wrapper.plan(
        None,
        None,
        None,
        kv_len_arr,
        num_heads,
        head_dim_ckv,
        head_dim_kpe,
        page_size,
        sm_scale,
        torch.bfloat16,
        torch.bfloat16,
    )

    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)

    # Expand the single shared KV head to every query head for the reference.
    joined = torch.cat([ckv, k_pe], dim=-1).view(-1, 1, head_dim_ckv + head_dim_kpe)
    k = joined.repeat_interleave(num_heads, dim=1)
    v = ckv.view(-1, 1, head_dim_ckv).repeat_interleave(num_heads, dim=1)

    print(k[:kv_len].shape)
    print(v[:kv_len].shape)

    attn_ref, lse_ref = attention_ref(
        max_batch_size,
        torch.cat([q_nope, q_pe], dim=-1),
        k[:kv_len],
        v[:kv_len],
        False,
        sm_scale,
    )

    torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
    print("test past")
\ No newline at end of file
ktransformers/operators/gate.py
View file @
038bc308
...
...
@@ -93,11 +93,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
=
None
,
generate_device
:
str
=
"cuda"
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
KMoEGateBase
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
self
.
generate_device
=
generate_device
self
.
prefill_device
=
prefill_device
...
...
ktransformers/operators/linear.py
View file @
038bc308
...
...
@@ -383,7 +383,7 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
prefill_op
:
str
|
None
=
"KLinearTorch"
,
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
KLinearBase
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
# build all the linear operators
if
prefill_op
is
not
None
:
...
...
ktransformers/util/utils.py
View file @
038bc308
...
...
@@ -17,6 +17,7 @@ from ktransformers.operators import base_operator
from
ktransformers.models.custom_cache
import
StaticCache
from
ktransformers.util.cuda_graph_runner
import
CUDAGraphRunner
from
ktransformers.util.textstream
import
TextStreamer
from
ktransformers.operators.flashinfer_wrapper
import
MLAWrapperSingleton
warm_uped
=
False
...
...
@@ -87,7 +88,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
module
.
load
()
def
prefill_and_generate
(
model
,
tokenizer
,
inputs
,
max_new_tokens
=
10000
,
use_cuda_graph
:
bool
=
True
,
mode
=
'normal'
,
force_think
:
bool
=
False
):
mode
=
'normal'
,
force_think
:
bool
=
False
,
use_flashinfer_mla
=
False
,
num_heads
=
None
,
head_dim_ckv
=
None
,
head_dim_kpe
=
None
,
q_head_dim
=
None
):
import
os
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
torch
.
_dynamo
.
config
.
suppress_errors
=
True
...
...
@@ -137,7 +139,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
)
else
:
past_key_values
=
None
cache_position
=
torch
.
arange
(
seq_length
,
device
=
torch_device
,
dtype
=
torch
.
long
)
cache_position
=
torch
.
arange
(
seq_length
,
device
=
torch_device
,
dtype
=
torch
.
int32
)
generated_ids
=
torch
.
zeros
(
batch_size
,
seq_length
+
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
torch_device
)
...
...
@@ -182,7 +184,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
generated_ids
[:,
seq_length
]
=
next_token
tokens
.
append
(
int
(
next_token
))
inputs
=
torch
.
cat
((
inputs
,
next_token
.
unsqueeze
(
0
)),
dim
=-
1
)
cache_position
=
torch
.
tensor
([
seq_length
],
device
=
torch_device
,
dtype
=
torch
.
long
)
cache_position
=
torch
.
tensor
([
seq_length
],
device
=
torch_device
,
dtype
=
torch
.
int32
)
position_ids
=
cache_position
.
unsqueeze
(
0
)
seq_length
+=
1
...
...
@@ -195,7 +197,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
warm_uped
=
True
cuda_graph_runner
=
CUDAGraphRunner
()
cuda_graph_runner
.
capture
(
model
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
torch_device
,
return_dict
=
False
,
use_cache
=
True
)
if
i
>
1
and
use_flashinfer_mla
:
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
position_ids
.
squeeze
(
1
)
+
1
,
num_heads
,
head_dim_ckv
,
head_dim_kpe
,
past_key_values
.
page_size
,
q_head_dim
**
(
-
0.5
),
torch
.
bfloat16
,
torch
.
bfloat16
)
next_token
=
decode_one_tokens
(
cuda_graph_runner
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
use_cuda_graph
).
to
(
torch_device
)
inputs
=
torch
.
cat
((
inputs
,
next_token
.
unsqueeze
(
0
)),
dim
=-
1
)
generated_ids
[:,
cache_position
]
=
next_token
.
int
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment