OpenDAS / ktransformers · Commits

Unverified commit 9b76cab1
Authored Mar 15, 2025 by ZiWei Yuan; committed by GitHub on Mar 15, 2025
Parents: dfe09b05, b5ef7c26

Merge pull request #898 from kvcache-ai/develop-0.2.3post2

Release 0.2.3post2
Changes: 32
Showing 12 changed files with 821 additions and 31 deletions (+821 −31)
ktransformers/local_chat.py                                          +2   −1
ktransformers/operators/attention.py                                 +26  −12
ktransformers/operators/dynamic_attention.py                         +4   −1
ktransformers/operators/linear.py                                    +174 −5
ktransformers/operators/models.py                                    +4   −2
ktransformers/operators/triton_attention.py                          +3   −3
ktransformers/operators/triton_attention_prefill.py                  +206 −0
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml     +1   −1
ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml     +76  −0
ktransformers/tests/test_pytorch_q8.py                               +46  −0
ktransformers/util/vendors.py                                        +202 −0
setup.py                                                             +77  −6
ktransformers/local_chat.py (view file @ 9b76cab1)

@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -169,7 +170,7 @@ def local_chat(
     assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
         "please change max_seq_len in ~/.ktransformers/config.yaml"
-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think, chunk_prefill_size=chunk_prefill_size,
             use_flashinfer_mla=True, num_heads=config.num_attention_heads, head_dim_ckv=config.kv_lora_rank, head_dim_kpe=config.qk_rope_head_dim, q_head_dim=config.qk_rope_head_dim + config.qk_nope_head_dim
         )
ktransformers/operators/attention.py (view file @ 9b76cab1)

@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+try:
+    from flash_attn import flash_attn_func
+except:
+    pass
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

             value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
-            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states_padded,
-                softmax_scale=self.softmax_scale,
-                causal=True,
-            )
+            # for bsz = 1
+            attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
+            b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
+            b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
+            max_input_len = q_len
+            context_attention_fwd(
+                q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+                k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+                v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
+                o=attn_output,
+                b_start_loc=b_start_loc,
+                b_seq_len=b_seq_len,
+                max_input_len=max_input_len,
+                is_causal=True
+            )

             if self.q_head_dim != self.v_head_dim:
-                attn_output = attn_output[:, :, :, : self.v_head_dim]
+                attn_output = attn_output[:, :, : self.v_head_dim]

             attn_output = attn_output.reshape(
                 bsz, q_len, self.num_heads * self.v_head_dim
@@ -589,8 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability() < 8:
-            print("for Windows or GPU before ampere, use forward_windows")
+        if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
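The replacement prefill path feeds context_attention_fwd flattened [bsz * q_len, heads, head_dim] tensors plus per-sequence start offsets and lengths; the diff above only constructs the bsz = 1 case, where b_start_loc is all zeros. As an illustration (not part of this commit), the same metadata for a packed batch of several variable-length sequences could be derived from the per-sequence lengths like this:

import torch

# Illustrative only: build the b_start_loc / b_seq_len metadata that
# context_attention_fwd expects for a packed variable-length batch.
seq_lens = [5, 3, 7]                                   # hypothetical per-sequence lengths
b_seq_len = torch.tensor(seq_lens, dtype=torch.int64)  # [b]
b_start_loc = torch.zeros_like(b_seq_len)
b_start_loc[1:] = torch.cumsum(b_seq_len, dim=0)[:-1]  # exclusive prefix sum of lengths
max_input_len = int(b_seq_len.max())

print(b_start_loc.tolist())  # [0, 5, 8]
print(max_input_len)         # 7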
ktransformers/operators/dynamic_attention.py (view file @ 9b76cab1)

@@ -17,7 +17,10 @@ import logging
 logger = logging.getLogger("dynamic_attention")
 sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
 from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+try:
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
+except:
+    print("falsh attn not found")
 import math
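Both attention modules now wrap the flash_attn import in try/except so the package stays importable on platforms (such as ROCm) where flash_attn is not installed. A minimal sketch of the same pattern with an explicit availability flag; this is my own variant, not what the commit itself does:

# Sketch of an optional-import guard with an explicit flag (assumption,
# not the commit's exact code, which only prints a warning).
try:
    from flash_attn import flash_attn_func, flash_attn_with_kvcache
    flash_attn_available = True
except ImportError:
    flash_attn_func = None
    flash_attn_with_kvcache = None
    flash_attn_available = False

if not flash_attn_available:
    print("flash_attn not found; falling back to the Triton kernels")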
ktransformers/operators/linear.py (view file @ 9b76cab1)

@@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 import cpuinfer_ext
 from ktransformers.operators.cpuinfer import CPUInfer
 from ktransformers.server.config.config import Config
+from typing import Dict, Tuple, Optional, Union
+import numpy as np

 #class KLinearBase(BaseInjectedModule, ABC):
 class KLinearBase(ABC):
@@ -176,16 +178,182 @@ class KLinearTorch(KLinearBase):
         if self.has_bias:
             self.bias = None

+class KLinearQ8(KLinearBase):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "cuda",
+        **kwargs,
+    ):
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.has_bias = False
+        self.compute_dtype = torch.float32
+        self.weight = None
+        self.weight_scale = None
+        self.weight_zero_point = None
+        self.bias = None
+        self.loaded = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_dtype = x.dtype
+        out_device = x.device
+        x = x.to(device=self.device, dtype=self.compute_dtype)
+        # Use the original weights for the matmul to mimic the original behavior
+        # Dequantize the weights for the matrix multiplication
+        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
+        out = x @ weight_dequant.T
+        if self.has_bias:
+            out = out + self.bias
+        return out.to(dtype=orig_dtype, device=out_device)
+
+    def _dequantize_weight(self, q_matrix, scales, bits=8):
+        """
+        Dequantize a low-precision matrix back to floating-point
+
+        Args:
+            q_matrix (torch.Tensor): Quantized int matrix
+            scales (torch.Tensor): Scale factors for each column
+            bits (int): Quantization bits used (8 or 4)
+
+        Returns:
+            torch.Tensor: Dequantized floating-point matrix
+        """
+        # Ensure inputs are torch tensors
+        if not isinstance(q_matrix, torch.Tensor):
+            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
+        if not isinstance(scales, torch.Tensor):
+            scales = torch.tensor(scales, dtype=torch.float32)
+        # Convert to correct dtype if needed
+        if q_matrix.dtype != torch.int8:
+            q_matrix = q_matrix.to(torch.int8)
+        if scales.dtype != torch.float32:
+            scales = scales.to(torch.float32)
+        # For Q4, ensure the values stay within 4-bit range
+        if bits == 4:
+            q_matrix = torch.clamp(q_matrix, -7, 7)
+        rows, cols = q_matrix.shape
+        dequant_matrix = q_matrix.to(torch.float32)
+        scales_broadcast = scales.view(1, cols)
+        # Apply dequantization to all columns at once
+        dequant_matrix = dequant_matrix * scales_broadcast
+        return dequant_matrix
+
+    def _quantize_weight(self, matrix, bits=8):
+        """
+        Quantize a floating-point matrix to lower precision (Q8 or Q4)
+
+        Args:
+            matrix (torch.Tensor): Input matrix in floating-point format
+            bits (int): Quantization bits, either 8 or 4
+
+        Returns:
+            tuple: (quantized int matrix, scale factors for each column)
+        """
+        if not isinstance(matrix, torch.Tensor):
+            matrix = torch.tensor(matrix, dtype=torch.float32)
+        # Convert to float32 if needed
+        if matrix.dtype != torch.float32:
+            matrix = matrix.to(torch.float32)
+        # Get matrix shape
+        rows, cols = matrix.shape
+        # Determine quantization parameters based on bits
+        if bits == 8:
+            max_int = 127
+            qtype = torch.int8
+        elif bits == 4:
+            max_int = 7
+            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range, wait for native support
+        else:
+            raise ValueError("Quantization bits must be either 8 or 4")
+        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
+        # Calculate max absolute value for each column
+        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)
+        # Handle zero columns (avoid division by zero)
+        zero_cols = max_abs_vals == 0
+        max_abs_vals[zero_cols] = 1.0
+        # Calculate scale factors for all columns at once
+        scales = max_abs_vals / max_int
+        # Prepare the scales for broadcasting [1, cols]
+        scales_broadcast = scales.view(1, cols)
+        # Apply quantization to the entire matrix at once
+        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)
+        # For Q4, clamp values to ensure they stay within 4-bit range
+        if bits == 4:
+            q_matrix = torch.clamp(q_matrix, -max_int, max_int)
+        return q_matrix, scales
+
+    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
+        if self.loaded:
+            return
+        if device is None:
+            device = self.device
+        if w is None:
+            w = self.load_weight(device=device)
+
+        if isinstance(w, nn.Parameter):
+            try:
+                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
+            except:
+                weight = w.to(dtype=self.compute_dtype)
+            self.has_bias = False
+        elif isinstance(w, tuple):
+            try:
+                weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
+            except:
+                weight = w[0].to(dtype=self.compute_dtype)
+            self.bias = w[1].to(dtype=self.compute_dtype).to(device)
+            self.has_bias = True
+        else:
+            raise ValueError("Invalid weight type")
+        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
+        self.weight = self.weight.to(device)
+        self.weight_scale = self.weight_scale.to(device)
+        if self.has_bias:
+            self.bias = self.bias.to(device)
+        self.loaded = True
+
+    def unload(self):
+        self.weight = None
+        self.weight_scale = None
+        self.weight_zero_point = None
+        self._orig_weight = None
+        if self.has_bias:
+            self.bias = None
+        self.loaded = False
+
 class KLinearFP8(KLinearBase):
     # this kernel requires special handling for weight
     # Please load the weight file downloaded from KVCache.AI
-    marlin_q_w: torch.Tensor
-    marlin_s: torch.Tensor
-    g_idx: torch.Tensor
-    sort_indices: torch.Tensor
     has_bias: bool
     weight: torch.Tensor
+    scale_w: torch.Tensor
     bias: torch.Tensor
     def __init__(
         self,
@@ -468,6 +636,7 @@ LINEAR_MAP = {
     "KLinearTorch": KLinearTorch,
     "KLinearCPUInfer": KLinearCPUInfer,
     "KLinearFP8": KLinearFP8,
+    "KLinearQ8": KLinearQ8,
 }

 class KTransformersLinear(BaseInjectedModule, KLinearBase):
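KLinearQ8 keeps symmetric per-column int8 weights: each column's scale is max(|w|) / 127, quantization rounds w / scale to int8, and the forward pass dequantizes with q * scale before a plain matmul. A self-contained round-trip sketch of that scheme (my own illustration outside the class, with an arbitrary weight shape):

import torch

# Illustrative round trip of the symmetric per-column Q8 scheme used by
# KLinearQ8: scale each column by max(|w|) / 127, round to int8, dequantize
# by multiplying the scales back in, and look at the resulting error.
w = torch.randn(128, 64)                 # hypothetical weight matrix
max_abs = w.abs().amax(dim=0)            # per-column max absolute value
max_abs[max_abs == 0] = 1.0              # avoid division by zero
scales = max_abs / 127.0
q = torch.round(w / scales.view(1, -1)).to(torch.int8)
w_hat = q.to(torch.float32) * scales.view(1, -1)

print("mean abs weight error:", (w - w_hat).abs().mean().item())
x = torch.randn(4, 128)
print("max matmul drift:", (x @ w - x @ w_hat).abs().max().item())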
ktransformers/operators/models.py (view file @ 9b76cab1)

@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
     DeepseekV2DecoderLayer,
     DeepseekV2MoE,
 )
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
@@ -649,8 +650,8 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability() < 8:
-                print("for Windows or GPU before ampere, use forward_windows")
+            if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+                # print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
                     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
@@ -673,6 +674,7 @@ class KDeepseekV2Model(BaseInjectedModule):
         t_f = 0

         for i, decoder_layer in enumerate(self.layers):
+            # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
             if self.transfer_map is not None and i in self.transfer_map:
                 prev_stream = torch.cuda.current_stream()
                 cur_device = self.transfer_map[i]
ktransformers/operators/triton_attention.py (view file @ 9b76cab1)

@@ -6,7 +6,7 @@
 import triton
 import triton.language as tl
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
     # [TODO] work around shmem limit on MI3xx
     # TODO: support hip
-    # if is_hip_ and Lk >= 576:
-    #     BLOCK = 16
+    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
+        BLOCK = 16

     if Lk == 576:
         BLOCK_DMODEL = 512
ktransformers/operators/triton_attention_prefill.py — new file, mode 100644 (view file @ 9b76cab1)

# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
"""
Memory-efficient attention for prefill.
It supporst page size = 1.
"""
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    end_n = (
        cur_batch_seq_len
        if not IS_CAUSAL
        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    )
    for start_n in range(0, block_mask * end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(
                (start_n + offs_n[None, :] < cur_batch_seq_len)
                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
                0,
                float("-inf"),
            )
        else:
            qk += tl.where(
                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
            )

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]))


def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=BLOCK,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
    )
\ No newline at end of file
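Per its docstring, context_attention_fwd takes packed [b * s, head, head_dim] tensors plus per-sequence start offsets and lengths, and writes into a preallocated output tensor. A hedged usage sketch (the shapes follow the docstring; the sizes are arbitrary, and it needs a CUDA device with Triton installed):

import torch
from ktransformers.operators.triton_attention_prefill import context_attention_fwd

# Hedged usage sketch of the new prefill kernel wrapper.
bsz, q_len, num_heads, head_dim = 1, 16, 8, 64
device = "cuda"

q = torch.randn(bsz * q_len, num_heads, head_dim, device=device, dtype=torch.float16)
k = torch.randn(bsz * q_len, num_heads, head_dim, device=device, dtype=torch.float16)
v = torch.randn(bsz * q_len, num_heads, head_dim, device=device, dtype=torch.float16)
o = torch.zeros_like(v)  # output is written in place

b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=device)
b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=device)

context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len=q_len, is_causal=True)
print(o.shape)  # torch.Size([16, 8, 64])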
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml (view file @ 9b76cab1)

@@ -22,7 +22,7 @@
   replace:
     class: ktransformers.operators.linear.KTransformersLinear
     kwargs:
-      generate_device: "cuda"
+      generate_device: "cpu"
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml — new file, mode 100644 (view file @ 9b76cab1)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearCPUInfer"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cpu"
      prefill_device: "cuda"
      generate_op: "KLinearQ8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
\ No newline at end of file
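The ROCm rule file uses the same regex-based module matching as the other optimize rules; in particular, the negative lookahead in "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" (the YAML's "\\." escapes resolve to "\.") routes every layer Linear except kv_b_proj to the new KLinearQ8 CPU path. A small sketch, with hypothetical module names of my own, of how that pattern behaves:

import re

# Illustration of the rule's negative lookahead (hypothetical module names).
pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

names = [
    "model.layers.0.mlp.gate_proj",
    "model.layers.3.self_attn.q_a_proj",
    "model.layers.3.self_attn.kv_b_proj",   # excluded by the lookahead
]
for name in names:
    print(name, "->", bool(pattern.match(name)))
# model.layers.0.mlp.gate_proj -> True
# model.layers.3.self_attn.q_a_proj -> True
# model.layers.3.self_attn.kv_b_proj -> False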
ktransformers/tests/test_pytorch_q8.py — new file, mode 100644 (view file @ 9b76cab1)

import torch

# Define a floating-point model containing a linear layer
class LinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

# Create the floating-point model instance
in_features = 64
out_features = 128
model_fp32 = LinearModel(in_features, out_features)

# Create the quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,            # original floating-point model
    {torch.nn.Linear},     # set of layer types to quantize
    dtype=torch.qint8      # target data type of the quantization
)

# Test the model
batch_size = 32
input_fp32 = torch.randn(1, batch_size, in_features)  # generate random input data
output_int8 = model_int8(input_fp32)                  # run the data through the quantized model

# Print output shapes for verification
print(f"Input shape: {input_fp32.shape}")
print(f"Output shape: {output_int8.shape}")

# Compare the outputs of the original and quantized models
with torch.no_grad():
    output_fp32 = model_fp32(input_fp32)
    print(f"First few values of the FP32 output: {output_fp32[0, :5]}")
    print(f"First few values of the INT8 output: {output_int8[0, :5]}")

    # Compute the mean absolute error
    error = torch.abs(output_fp32 - output_int8).mean().item()
    print(f"Mean absolute error: {error}")

# Print model type information
print(f"Model type before quantization: {type(model_fp32.linear)}")
print(f"Model type after quantization: {type(model_int8.linear)}")
\ No newline at end of file
ktransformers/util/vendors.py — new file, mode 100644 (view file @ 9b76cab1)

from __future__ import annotations
from enum import IntEnum, auto
from typing import Optional, Union, List
import torch

class GPUVendor(IntEnum):
    NVIDIA = auto()
    AMD = auto()
    MooreThreads = auto()
    MetaX = auto()
    MUSA = auto()
    Unknown = auto()

class DeviceManager:
    """
    Device manager that provides a unified interface for handling different GPU vendors
    """
    def __init__(self):
        self.gpu_vendor = self._detect_gpu_vendor()
        self.available_devices = self._get_available_devices()

    def _detect_gpu_vendor(self) -> GPUVendor:
        """Detect GPU vendor type"""
        if not torch.cuda.is_available():
            # Check MUSA availability (assuming a musa module exists)
            try:
                import musa
                if musa.is_available():
                    return GPUVendor.MUSA
            except (ImportError, AttributeError):
                pass
            return GPUVendor.Unknown

        device_name = torch.cuda.get_device_name(0).lower()

        if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
            return GPUVendor.NVIDIA
        elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
            return GPUVendor.AMD
        elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
            return GPUVendor.MooreThreads
        elif any(name in device_name for name in ["metax", "meta"]):
            return GPUVendor.MetaX
        elif "musa" in device_name:
            return GPUVendor.MUSA

        # Backend check
        try:
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                return GPUVendor.AMD
            elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
                return GPUVendor.NVIDIA
        except:
            pass

        return GPUVendor.Unknown

    def _get_available_devices(self) -> List[int]:
        """Get list of available device indices"""
        devices = []
        if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
            devices = list(range(torch.cuda.device_count()))
        elif self.gpu_vendor == GPUVendor.MUSA:
            try:
                import musa
                devices = list(range(musa.device_count()))
            except (ImportError, AttributeError):
                pass
        return devices

    def get_device_str(self, device_id: Union[int, str]) -> str:
        """
        Get device string for the given device ID

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Device string representation (e.g., "cuda:0", "musa:1", "cpu")
        """
        if device_id == -1 or device_id == "cpu":
            return "cpu"
        if isinstance(device_id, int):
            if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
                if device_id < torch.cuda.device_count():
                    return f"cuda:{device_id}"
            elif self.gpu_vendor == GPUVendor.MUSA:
                try:
                    import musa
                    if device_id < musa.device_count():
                        return f"musa:{device_id}"
                except (ImportError, AttributeError):
                    pass
        return "cpu"

    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
        """
        Convert device ID to torch.device object

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            torch.device object
        """
        device_str = self.get_device_str(device_id)
        # Handle MUSA device
        if device_str.startswith("musa:"):
            try:
                import musa
                index = int(device_str.split(":")[-1])
                return musa.device(index)
            except (ImportError, ValueError, AttributeError):
                return torch.device("cpu")
        # Standard PyTorch device
        return torch.device(device_str)

    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
        """
        Move tensor to specified device

        Args:
            tensor: PyTorch tensor to move
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Tensor moved to the specified device
        """
        device = self.to_torch_device(device_id)
        return tensor.to(device)

    def is_available(self, index: int = 0) -> bool:
        """
        Check if device at specified index is available

        Args:
            index: Device index to check

        Returns:
            True if the device is available, False otherwise
        """
        if index < 0:
            return True  # CPU is always available
        return index in self.available_devices

    def get_all_devices(self) -> List[int]:
        """
        Get all available device indices

        Returns:
            List of available device indices (0, 1, 2, etc.)
        """
        return self.available_devices

# Create global device manager instance
device_manager = DeviceManager()

# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
    """
    Get torch.device object for the specified device ID

    Args:
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        torch.device object
    """
    return device_manager.to_torch_device(device_id)

def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
    """
    Move tensor to specified device

    Args:
        tensor: PyTorch tensor to move
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        Tensor moved to the specified device
    """
    return device_manager.move_tensor_to_device(tensor, device_id)

# Get devices
cpu_device = get_device(-1)       # CPU using index -1
cpu_device2 = get_device("cpu")   # CPU using string "cpu"
gpu0 = get_device(0)              # First GPU

# Move tensors
x = torch.randn(3, 3)
x_gpu = to_device(x, 0)       # Move to first GPU
x_cpu1 = to_device(x, -1)     # Move to CPU using index -1
x_cpu2 = to_device(x, "cpu")  # Move to CPU using string "cpu"
\ No newline at end of file
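The rest of this commit uses the module almost entirely through device_manager.gpu_vendor, gating the flashinfer/flash-attn fast paths on NVIDIA hardware. A condensed sketch of that gate as it appears in local_chat.py and attention.py; the helper function itself is my own illustration, not code from the commit:

import os
from ktransformers.util.vendors import device_manager, GPUVendor
from ktransformers.util.utils import get_compute_capability

def can_use_flash_path() -> bool:
    # Mirrors the conditions this commit adds: not Windows, an Ampere-or-newer
    # GPU, and an NVIDIA vendor (flashinfer availability is checked separately
    # by the callers).
    return (
        os.name != "nt"
        and get_compute_capability() >= 8
        and device_manager.gpu_vendor == GPUVendor.NVIDIA
    )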
setup.py (view file @ 9b76cab1)

@@ -29,7 +29,7 @@ import torch.version
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 from setuptools import setup, Extension
 from cpufeature.extension import CPUFeature
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 try:
     from torch_musa.utils.simple_porting import SimplePorting
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
@@ -64,6 +64,70 @@ class VersionInfo:
         musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
         return musa_version

+    def get_rocm_bare_metal_version(self, rocm_dir):
+        """
+        Get the ROCm version from the ROCm installation directory.
+
+        Args:
+            rocm_dir: Path to the ROCm installation directory
+
+        Returns:
+            A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
+        """
+        try:
+            # Try using rocm_agent_enumerator to get version info
+            raw_output = subprocess.check_output(
+                [rocm_dir + "/bin/rocminfo", "--version"],
+                universal_newlines=True,
+                stderr=subprocess.STDOUT
+            )
+            # Extract version number from output
+            match = re.search(r'(\d+\.\d+)', raw_output)
+            if match:
+                version_str = match.group(1)
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            # If rocminfo --version fails, try alternative methods
+            pass
+
+        try:
+            # Try reading version from release file
+            with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
+                version_str = f.read().strip()
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (FileNotFoundError, IOError):
+            pass
+
+        # If all else fails, try to extract from directory name
+        dir_name = os.path.basename(os.path.normpath(rocm_dir))
+        match = re.search(r'rocm-(\d+\.\d+)', dir_name)
+        if match:
+            version_str = match.group(1)
+            version = parse(version_str)
+            rocm_version = f"{version.major}{version.minor}"
+            return rocm_version
+
+        # Fallback to extracting from hipcc version
+        try:
+            raw_output = subprocess.check_output(
+                [rocm_dir + "/bin/hipcc", "--version"],
+                universal_newlines=True,
+                stderr=subprocess.STDOUT
+            )
+            match = re.search(r'HIP version: (\d+\.\d+)', raw_output)
+            if match:
+                version_str = match.group(1)
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            pass
+
+        # If we still can't determine the version, raise an error
+        raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
+
     def get_cuda_bare_metal_version(self, cuda_dir):
         raw_output = subprocess.check_output(
             [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -148,11 +212,13 @@ class VersionInfo:
         cpu_instruct = self.get_cpu_instruct()
         backend_version = ""
         if CUDA_HOME is not None:
-            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+            backend_version = f""
         elif MUSA_HOME is not None:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+        elif ROCM_HOME is not None:
+            backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
         else:
-            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+            raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
         package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
         if full_version:
             return package_version
@@ -247,8 +313,12 @@ class CMakeBuild(BuildExtension):
             cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
         elif MUSA_HOME is not None:
             cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+        elif ROCM_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
         else:
             raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

+        # log cmake_args
+        print("CMake args:", cmake_args)
         build_args = []
         if "CMAKE_ARGS" in os.environ:
@@ -328,7 +398,7 @@ class CMakeBuild(BuildExtension):
             ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
         )

-if CUDA_HOME is not None:
+if CUDA_HOME is not None or ROCM_HOME is not None:
     ops_module = CUDAExtension('KTransformersOps', [
         'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
         'ktransformers/ktransformers_ext/cuda/binding.cpp',
@@ -338,7 +408,7 @@ if CUDA_HOME is not None:
             'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
             'nvcc': [
                 '-O3',
-                '--use_fast_math',
+                # '--use_fast_math',
                 '-Xcompiler', '-fPIC',
                 '-DKTRANSFORMERS_USE_CUDA',
             ]
@@ -371,6 +441,7 @@ else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

 setup(
+    name=VersionInfo.PACKAGE_NAME,
     version=VersionInfo().get_package_version(),
     cmdclass={"bdist_wheel": BuildWheelsCommand, "build_ext": CMakeBuild},
     ext_modules=[
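The new get_rocm_bare_metal_version reduces whatever version string it finds to "<major><minor>" (e.g. ROCm 6.3 becomes "63"), which then lands in the wheel's local version as the "rocm63" backend tag. A toy sketch of that reduction step, assuming the packaging.version.parse import that the method's parse() calls imply, with a hypothetical hipcc banner:

import re
from packaging.version import parse  # assumed; setup.py's parse() calls imply this import

# Toy reduction of a HIP version banner to the "rocm<major><minor>" backend tag.
raw_output = "HIP version: 6.3.42131-fa1d09cbd"   # hypothetical hipcc output
match = re.search(r"HIP version: (\d+\.\d+)", raw_output)
version = parse(match.group(1))
backend_version = f"rocm{version.major}{version.minor}"
print(backend_version)  # rocm63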