OpenDAS / ktransformers / Commits / 038bc308

Commit 038bc308, authored Feb 17, 2025 by Atream

    fix precision bug imported by position_ids in 0.2.0

Parent: b8452462

Showing 10 changed files with 471 additions and 45 deletions (+471 −45).
Changed files:

    ktransformers/local_chat.py                      +11   −3
    ktransformers/models/custom_cache.py              +0   −2
    ktransformers/operators/RoPE.py                   +9   −8
    ktransformers/operators/attention.py            +189  −22
    ktransformers/operators/base_operator.py          +5   −2
    ktransformers/operators/experts.py                +5   −1
    ktransformers/operators/flashinfer_wrapper.py   +240   −0
    ktransformers/operators/gate.py                   +2   −2
    ktransformers/operators/linear.py                 +1   −1
    ktransformers/util/utils.py                       +9   −4
ktransformers/local_chat.py

@@ -30,6 +30,7 @@ from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate
 from ktransformers.server.config.config import Config
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,

@@ -170,9 +171,16 @@ def local_chat(
     torch.set_default_dtype(
         torch.bfloat16
     )  # TODO: Remove this, replace dtype using config
-    generated = prefill_and_generate(
-        model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
-    )
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think,
+            use_flashinfer_mla=True, num_heads=config.num_attention_heads, head_dim_ckv=config.kv_lora_rank,
+            head_dim_kpe=config.qk_rope_head_dim, q_head_dim=config.qk_rope_head_dim + config.qk_nope_head_dim
+        )
+    else:
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think,
+        )

 if __name__ == "__main__":
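Editor's note: the new branch wires the MLA geometry from the model config into prefill_and_generate. (The guard `config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM"` is always truthy because the bare string on the right is truthy, so in practice the check reduces to the non-Windows and flashinfer_enabled conditions.) A minimal sketch of how those dimensions are derived, using a hypothetical config object whose values match the defaults exercised by the self-test at the bottom of flashinfer_wrapper.py:

```python
# Minimal sketch (hypothetical config, DeepSeek-V2/V3-style default values):
# how the MLA dimensions passed to prefill_and_generate are derived.
class DummyConfig:
    kv_lora_rank = 512        # compressed-KV ("ckv") head dim
    qk_rope_head_dim = 64     # rotary ("kpe") head dim
    qk_nope_head_dim = 128    # non-rotary query/key head dim
    num_attention_heads = 128

cfg = DummyConfig()
head_dim_ckv = cfg.kv_lora_rank                             # 512
head_dim_kpe = cfg.qk_rope_head_dim                         # 64
q_head_dim = cfg.qk_rope_head_dim + cfg.qk_nope_head_dim    # 192
sm_scale = q_head_dim ** -0.5                               # softmax scale used by the MLA kernel
print(head_dim_ckv, head_dim_kpe, q_head_dim, round(sm_scale, 4))
```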
ktransformers/models/custom_cache.py

@@ -138,8 +138,6 @@ class StaticCache(transformers.StaticCache):
         page_idx = cache_position // self.page_size
         page_offset = cache_position % self.page_size
         # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
-        #print("page_idx", page_idx)
-        #print("page_offset", page_offset)
         k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
         k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
         return k_out, self.page_table_list[layer_idx]
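Editor's note: the paged-cache indexing kept above splits a flat cache position into a page index and an offset within the page. A standalone illustration (not repository code; page size chosen to match the value used elsewhere in this commit):

```python
import torch

# Standalone sketch of paged-cache indexing: a flat cache position maps to
# (page, offset-within-page). Values are illustrative.
page_size = 64
cache_position = torch.tensor([0, 1, 63, 64, 130], dtype=torch.int32)

page_idx = cache_position // page_size     # which page each token lands in
page_offset = cache_position % page_size   # slot inside that page

print(page_idx.tolist())     # [0, 0, 0, 1, 2]
print(page_offset.tolist())  # [0, 1, 63, 0, 2]
```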
ktransformers/operators/RoPE.py

@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim, orig_module.max_position_embeddings, orig_module.base

@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device

@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,

@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,

@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
     #     **kwargs,
     # ):
     #     BaseInjectedModule.__init__(
-    #         self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+    #         self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
     #     )
     #     self.generate_device = generate_device
     #     self.prefill_device = prefill_device

@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device

@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
ktransformers/operators/attention.py

@@ -19,9 +19,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_with_kvcache, flash_attn_func
+from flash_attn import flash_attn_func
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 import os
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+if flashinfer_enabled:
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
+
 logger = logging.getLogger("attention")

 # Copied from transformers.models.llama.modeling_llama.rotate_half

@@ -41,15 +45,15 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         chunck_size: int = 1000,
-        use_triton: bool = False,
         **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
-        self.use_triton = use_triton
+        self.mla_wrapper = None

     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
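Editor's note: get_absorbed() returns the "absorbed" projection matrices that let the decode path attend directly in the compressed kv_lora_rank space. A shape-only sketch of that trick, with dimensions taken from the tensor-shape comments in the forward_linux_flashinfer hunk further down (the concrete values are assumed DeepSeek-V2 defaults):

```python
import torch

# Shape-only sketch of the absorbed-projection trick used in the decode path.
# Dimensions follow the comments in forward_linux_flashinfer; values assumed.
bsz, q_len, num_heads = 1, 1, 128
qk_nope_head_dim, kv_lora_rank, v_head_dim = 128, 512, 128

q_nope = torch.randn(bsz, q_len, num_heads, qk_nope_head_dim)
q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)

# Project the "nope" part of the query into the compressed latent space ...
q_latent = torch.matmul(q_nope.transpose(1, 2), q_absorb).transpose(1, 2)
print(q_latent.shape)     # torch.Size([1, 1, 128, 512])

# ... so an attention output computed in latent space maps back to v_head_dim.
attn_latent = torch.randn(bsz, num_heads, q_len, kv_lora_rank)
attn_out = torch.matmul(attn_latent, out_absorb.mT)
print(attn_out.shape)     # torch.Size([1, 128, 1, 128])
```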
@@ -141,6 +145,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         #print(compressed_kv.shape)
         attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
         #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
         compressed_kv = compressed_kv.squeeze(1)
         """

@@ -168,8 +173,9 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_weights = nn.functional.dropout(
             attn_weights, p=self.attention_dropout, training=self.training
         )
         attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
         attn_output = torch.matmul(attn_output, out_absorb.mT)
         if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):

@@ -179,14 +185,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         )
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

-    def forward_linux(
+    def forward_linux_triton(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,

@@ -267,7 +273,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # use triton attention kernel adapted from vLLM and SGLang for MQA
             decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                              page_table,
-                             position_ids.squeeze(0).to(torch.int32), attn_logits,
+                             position_ids.squeeze(0).to(torch.int32)+1, attn_logits,
                              4, #num_kv_splits # follow vLLM, fix it TODO
                              self.softmax_scale,
                              past_key_value.page_size)
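Editor's note: the one-line change in the triton decode call is the precision fix named in the commit title. position_ids is 0-based, so the attended KV length must be position + 1; without the +1 the newest cached token is silently ignored. A small self-contained illustration (not repository code):

```python
import torch

# Self-contained illustration of the off-by-one this hunk fixes: with a
# 0-based position index, the valid KV length is position + 1.
position_ids = torch.tensor([[7]])   # decoding the token at position 7 (the 8th token)

kv_len_wrong = position_ids.squeeze(0).to(torch.int32)       # 7 -> last cached entry ignored
kv_len_fixed = position_ids.squeeze(0).to(torch.int32) + 1   # 8 -> all cached entries attended

print(int(kv_len_wrong), int(kv_len_fixed))  # 7 8
```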
@@ -325,6 +331,154 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         ).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
+        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+
+        cos, sin = self.rotary_emb(q_pe, position_ids)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
+        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
+
+        # decode
+        if q_len == 1:
+            if past_key_value is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)
+                k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)
+                # k_pe [max_pages, page_size, self.qk_rope_head_dim]
+                # compressed_kv [max_pages, page_size, self.kv_lora_rank]
+
+            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
+            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
+            q_absorb, out_absorb = self.get_absorbed()
+            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
+            q_nope = torch.matmul(q_nope, q_absorb) # batched MM
+            q_nope = q_nope.transpose(1, 2)
+            assert q_nope.is_contiguous()
+
+            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
+            q_nope.squeeze_(1)
+            q_pe.squeeze_(1)
+
+            # flash attn doesn't support head_dim bigger than 256, use flashinfer
+            if self.mla_wrapper is None:
+                self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph=True)
+            if self.mla_wrapper.need_plan:
+                self.mla_wrapper.need_plan = False
+                self.mla_wrapper.plan(None, None, None,
+                                      position_ids.squeeze(1)+1,
+                                      self.num_heads,
+                                      self.kv_lora_rank,
+                                      self.qk_rope_head_dim,
+                                      past_key_value.page_size,
+                                      self.softmax_scale,
+                                      q_nope.dtype,
+                                      compressed_kv.dtype)
+            attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
+            """
+            k = (
+                torch.cat([compressed_kv, k_pe], dim=-1)
+                .view(-1, 1, 512 + 64)
+                .repeat_interleave(self.num_heads, dim=1)
+            )
+            v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)
+            lens = position_ids.item() + 1
+            #print("lens", lens)
+            attn_ref, lse_ref = attention_ref(
+                1,
+                torch.cat([q_nope, q_pe], dim=-1),
+                k[:lens],
+                v[:lens],
+                False,
+                self.softmax_scale
+            )
+            attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
+            """
+
+            # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
+            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+            attn_output = self.o_proj(attn_output)
+            return attn_output, None, past_key_value
+        else:
+            if past_key_value is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+                k_pe.squeeze(0)
+                compressed_kv.squeeze(0)
+                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                k_pe.unsqueeze(0)
+                compressed_kv.unsqueeze(0)
+            k_pe = k_pe[:, :q_len]
+            compressed_kv = compressed_kv[:, :q_len]
+            kv = (
+                self.kv_b_proj(compressed_kv)
+                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            )
+            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
+            query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
+
+            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
+            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
+
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states_padded,
+                softmax_scale=self.softmax_scale,
+                causal=True,
+            )
+
+            if self.q_head_dim != self.v_head_dim:
+                attn_output = attn_output[:, :, :, :self.v_head_dim]
+
+            attn_output = attn_output.reshape(
+                bsz, q_len, self.num_heads * self.v_head_dim
+            ).contiguous()
+            attn_output = self.o_proj(attn_output)
+            return attn_output, None, past_key_value
+
     def forward_windows(
         self,
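Editor's note: in the prefill branch above, flash_attn_func is given a value tensor right-padded to the query/key head dim, and the output is cropped back to v_head_dim afterwards. A standalone sketch of just that padding arithmetic, with assumed dimensions:

```python
import torch

# Standalone sketch of the pad-then-crop trick in the prefill branch
# (assumed dimensions: q_head_dim=192, v_head_dim=128).
bsz, q_len, num_heads = 1, 4, 8
q_head_dim, v_head_dim = 192, 128

query_states = torch.randn(bsz, q_len, num_heads, q_head_dim)
value_states = torch.randn(bsz, q_len, num_heads, v_head_dim)

# Right-pad the last dim of V so it matches the query/key head dim ...
value_states_padded = torch.nn.functional.pad(
    value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0
)
print(value_states_padded.shape)   # torch.Size([1, 4, 8, 192])

# ... then crop the attention output back to the true value head dim.
attn_output = torch.randn(bsz, q_len, num_heads, q_head_dim)  # stand-in for the kernel output
attn_output = attn_output[:, :, :, :v_head_dim]
print(attn_output.shape)           # torch.Size([1, 4, 8, 128])
```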
@@ -403,7 +557,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if not self.use_triton: # os.name == 'nt'
+        if os.name == 'nt':
             return self.forward_windows(
                 hidden_states,
                 attention_mask,

@@ -415,16 +569,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 **kwargs,
             )
         else:
-            return self.forward_linux(
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                use_cache,
-                cache_position,
-                **kwargs,
-            )
+            if flashinfer_enabled:
+                return self.forward_linux_flashinfer(
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    **kwargs,
+                )
+            else:
+                return self.forward_linux_triton(
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    **kwargs,
+                )

 class KLlamaAttention(BaseInjectedModule):

@@ -435,9 +601,10 @@ class KLlamaAttention(BaseInjectedModule):
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
     def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
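Editor's note: after this change the forward dispatch no longer depends on a use_triton flag. Windows falls back to forward_windows, Linux uses the flashinfer MLA path when the package imports, and otherwise the triton path. A compact sketch of that selection logic (stand-in return values, not the real methods):

```python
import os

# Compact sketch of the backend selection introduced above (stand-in names).
def pick_attention_backend(flashinfer_enabled: bool) -> str:
    if os.name == 'nt':                  # Windows
        return "forward_windows"
    if flashinfer_enabled:               # Linux with flashinfer installed
        return "forward_linux_flashinfer"
    return "forward_linux_triton"        # Linux fallback: triton MQA kernel

print(pick_attention_backend(True))
print(pick_attention_backend(False))
```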
ktransformers/operators/base_operator.py

@@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs):
         nn.Module.__init__(self)
         nn.Module.__setattr__(self, "orig_module", orig_module)
         object.__setattr__(self, "key", key)
         object.__setattr__(self, "gguf_loader", gguf_loader)
         object.__setattr__(self, "config", config)
-        object.__setattr__(self, "device", device)
+        object.__setattr__(self, "prefill_device", prefill_device)
+        object.__setattr__(self, "generate_device", generate_device)
+        object.__setattr__(self, "device", generate_device)

     def __getattr__(self, name: str) -> Any:
         # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
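Editor's note: this is the base-class side of the signature change repeated across RoPE.py, attention.py, experts.py, gate.py, and linear.py. A toy sketch of the attribute layout an injected module now exposes (the real constructor also stores key, gguf_loader, config, and orig_module):

```python
import torch.nn as nn

# Toy sketch of the attribute layout BaseInjectedModule now sets up:
# prefill_device and generate_device are stored separately, and the plain
# "device" attribute defaults to the generate device.
class InjectedSketch(nn.Module):
    def __init__(self, prefill_device: str = "cuda", generate_device: str = "cuda"):
        super().__init__()
        object.__setattr__(self, "prefill_device", prefill_device)
        object.__setattr__(self, "generate_device", generate_device)
        object.__setattr__(self, "device", generate_device)  # default device follows generation

m = InjectedSketch(prefill_device="cuda:0", generate_device="cuda:1")
print(m.prefill_device, m.generate_device, m.device)  # cuda:0 cuda:1 cuda:1
```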
ktransformers/operators/experts.py

@@ -119,6 +119,7 @@ class KExpertsCPU(KExpertsBase):
     output_cpu:Tensor = None
     output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
     #stream_map:dict = {} # Manage cuda stream on different gpu
+    #gguf_loader:GGUFLoader = None
     CPU_INFER = CPUInfer(Config().cpu_infer)
     def __init__(
         self,

@@ -132,6 +133,9 @@ class KExpertsCPU(KExpertsBase):
         **kwargs
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        #if KExpertsCPU.gguf_loader is None:
+        #    KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
+        self.gguf_loader = gguf_loader
         assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
         self.n_routed_experts = n_routed_experts
         self.out_device = out_device

@@ -532,7 +536,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
         generate_device: str = "cpu",
         generate_op: str | None = "KExpertsCPU",
         **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         if generate_op is not None:
             self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
ktransformers/operators/flashinfer_wrapper.py (new file, 0 → 100644)

'''
Description  : flashinfer MLA wrapper
Author       : Boxin Zhang
Version      : 0.2.2
'''
import torch

flashinfer_enabled = False

try:
    import flashinfer
    flashinfer_enabled = True
    print("found flashinfer")
except ImportError:
    print("flashinfer not found, use triton for linux")

import math

def attention_ref(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
) -> torch.Tensor:
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]
    logits = (
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
        )
        * sm_scale
    )

    #print("attn weights", logits)

    if causal:
        mask = (
            torch.arange(kv_len - qo_len, kv_len).unsqueeze(1)
            >= torch.arange(0, kv_len).unsqueeze(0)
        ).to(q.device)
    else:
        mask = torch.ones(qo_len, kv_len).to(q.device)

    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
    p = torch.softmax(logits, dim=-1)
    o_ref = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            p,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )

    return o_ref, lse_ref * math.log2(math.e)

class MLAWrapper():
    def __init__(self,
                 max_batch_size,
                 max_pages,
                 use_cuda_graph = True,
                 device = "cuda",
                 ):
        self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device)
        self.max_batch_size = max_batch_size
        self.max_pages = max_pages
        if use_cuda_graph:
            if self.max_batch_size == 1:
                self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device)
                self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
            else:
                self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
            self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
        else:
            self.qo_indptr_buf = None
            self.kv_indptr_buf = None
            self.kv_indices_buf = None
            self.kv_len_arr_buf = None
        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
            self.float_workspace_buffer,
            use_cuda_graph=False,
            qo_indptr=self.qo_indptr_buf,
            kv_indptr=self.kv_indptr_buf,
            kv_indices=self.kv_indices_buf,
            kv_len_arr=self.kv_len_arr_buf,
        )
        self.need_plan = True

    def plan(self,
             qo_indptr,
             kv_indptr,
             kv_indices,
             kv_len_arr,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
             page_size,
             sm_scale,
             q_data_type,
             kv_data_type,
             ):
        if qo_indptr is None:
            assert self.max_batch_size == 1
            qo_indptr = self.qo_indptr_buf
        if kv_indptr is None:
            assert self.max_batch_size == 1
            kv_indptr = self.kv_indptr_buf
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf

        self.wrapper.plan(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_len_arr,
            num_heads,
            head_dim_ckv,
            head_dim_kpe,
            page_size,
            False, # causal is False for decoding
            sm_scale,
            q_data_type,
            kv_data_type,
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)

class MLAWrapperSingleton():
    wrappers:dict = {}

    @classmethod
    def get_instance(cls, device, *args, **kwargs) -> MLAWrapper:
        if device not in cls.wrappers:
            cls.make_instance(device, *args, **kwargs)
        return cls.wrappers[device]

    @classmethod
    def make_instance(cls, device, *args, **kwargs):
        cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device)

    @classmethod
    def plan_all(cls, qo_indptr, kv_indptr, kv_indices, kv_len_arr,
                 num_heads, head_dim_ckv, head_dim_kpe, page_size, sm_scale, q_data_type, kv_data_type,):
        for device, wrapper in cls.wrappers.items():
            kv_len_arr_cur_device = kv_len_arr.to(device)
            wrapper.plan(qo_indptr, kv_indptr, kv_indices, kv_len_arr_cur_device,
                         num_heads, head_dim_ckv, head_dim_kpe, page_size, sm_scale, q_data_type, kv_data_type,)

if __name__ == "__main__":
    max_batch_size = 1
    max_pages = 1
    page_size = 64
    num_heads = 128

    q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")

    wrapper = MLAWrapperSingleton.get_instance(
        "cuda",
        max_batch_size,
        max_pages,
    )

    kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
    wrapper.plan(
        None,
        None,
        None,
        kv_len_arr,
        128,
        512,
        64,
        page_size,
        192 ** (-0.5),
        torch.bfloat16,
        torch.bfloat16,
    )

    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)

    k = (
        torch.cat([ckv, k_pe], dim=-1)
        .view(-1, 1, 512 + 64)
        .repeat_interleave(num_heads, dim=1)
    )
    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)

    print(k[:10].shape)
    print(v[:10].shape)

    attn_ref, lse_ref = attention_ref(
        max_batch_size,
        torch.cat([q_nope, q_pe], dim=-1),
        k[:10],
        v[:10],
        False,
        192 ** (-0.5)
    )
    torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)

    print("test past")
(no newline at end of file)
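Editor's note: the 192 ** (-0.5) scale used in the self-test is 1/sqrt(q_head_dim), where q_head_dim = qk_nope_head_dim + qk_rope_head_dim (128 + 64 assumed, the DeepSeek-V2/V3 defaults); the 512 and 64 head dims correspond to kv_lora_rank and qk_rope_head_dim. A one-line check of that arithmetic:

```python
import math

# The MLA softmax scale is 1/sqrt(q_head_dim), with
# q_head_dim = qk_nope_head_dim + qk_rope_head_dim (128 + 64 = 192 assumed).
qk_nope_head_dim, qk_rope_head_dim = 128, 64
q_head_dim = qk_nope_head_dim + qk_rope_head_dim
assert abs(192 ** (-0.5) - 1 / math.sqrt(q_head_dim)) < 1e-12
print(q_head_dim, 192 ** (-0.5))
```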
ktransformers/operators/gate.py

@@ -93,11 +93,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module = None,
-        generate_device: str = "cuda",
         prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs,
     ):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         self.generate_device = generate_device
         self.prefill_device = prefill_device
ktransformers/operators/linear.py

@@ -383,7 +383,7 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         prefill_op: str | None = "KLinearTorch",
         **kwargs,
     ):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         # build all the linear operators
         if prefill_op is not None:
ktransformers/util/utils.py

@@ -17,6 +17,7 @@ from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
+from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

 warm_uped = False

@@ -87,7 +88,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()

 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False):
+                         mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
+                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True

@@ -137,7 +139,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         )
     else:
         past_key_values = None
-    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
+    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
     generated_ids = torch.zeros(
         batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
     )

@@ -182,7 +184,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     generated_ids[:, seq_length] = next_token
     tokens.append(int(next_token))
     inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
-    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
+    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
     position_ids = cache_position.unsqueeze(0)
     seq_length += 1

@@ -195,7 +197,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+            if i > 1 and use_flashinfer_mla:
+                MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1)+1,
+                                             num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                                             q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
             next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
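Editor's note: because the decode step is replayed through a CUDA graph, the flashinfer wrapper cannot infer the growing KV length from its inputs; instead plan_all is called once per generated token with the current length, again computed as position_ids + 1 (the very first decode step is planned inside the attention module via need_plan). A small loop sketch of that bookkeeping, with a stand-in plan function and no GPU work:

```python
import torch

# Sketch of the per-token re-planning bookkeeping: each decode step advances
# cache_position/position_ids by one and re-plans with kv_len = position + 1.
# plan_all_stub is a stand-in that just records the lengths it was given.
planned_lengths = []

def plan_all_stub(kv_len_arr: torch.Tensor) -> None:
    planned_lengths.append(int(kv_len_arr))

seq_length = 5                                   # prompt length after prefill
cache_position = torch.tensor([seq_length], dtype=torch.int32)
position_ids = cache_position.unsqueeze(0)

for i in range(1, 4):                            # three decode steps
    if i > 1:                                    # step 1 was planned inside the attention module
        plan_all_stub(position_ids.squeeze(1) + 1)
    cache_position += 1                          # next token position
    position_ids = cache_position.unsqueeze(0)

print(planned_lengths)                           # [7, 8]
```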