OpenDAS / ktransformers / Commits / 3986e2d2

Unverified commit 3986e2d2, authored Mar 15, 2025 by Azure, committed by GitHub on Mar 15, 2025.

Merge pull request #178 from fxzjshm/hip

[Feat] Port to ROCm/HIP

Parents: 8320ae7d, e5b001d7

Changes: 31. Showing 11 changed files with 819 additions and 30 deletions (+819 -30).
ktransformers/operators/attention.py                                 +26   -12
ktransformers/operators/dynamic_attention.py                          +4    -1
ktransformers/operators/linear.py                                   +174    -5
ktransformers/operators/models.py                                     +4    -2
ktransformers/operators/triton_attention.py                           +3    -3
ktransformers/operators/triton_attention_prefill.py                 +206    -0
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml       +1    -1
ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml      +76    -0
ktransformers/tests/test_pytorch_q8.py                               +46    -0
ktransformers/util/vendors.py                                       +202    -0
setup.py                                                             +77    -6
ktransformers/operators/attention.py

@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+
+try:
+    from flash_attn import flash_attn_func
+except:
+    pass
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:

@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

-        attn_output = flash_attn_func(
-            query_states,
-            key_states,
-            value_states_padded,
-            softmax_scale=self.softmax_scale,
-            causal=True,
+        # for bsz = 1
+        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
+        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
+        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
+        max_input_len = q_len
+        context_attention_fwd(
+            q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+            k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+            v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
+            o=attn_output,
+            b_start_loc=b_start_loc,
+            b_seq_len=b_seq_len,
+            max_input_len=max_input_len,
+            is_causal=True,
         )

         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, :, :self.v_head_dim]
+            attn_output = attn_output[:, :, :self.v_head_dim]

         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

@@ -589,8 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability() < 8:
-            print("for Windows or GPU before ampere, use forward_windows")
+        if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
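For orientation, the new prefill fallback flattens the (batch = 1) tensors into the packed [tokens, heads, head_dim] layout that context_attention_fwd expects, with b_start_loc/b_seq_len describing where each sequence starts and how long it is. A minimal sketch of that bookkeeping, with purely illustrative sizes (not taken from any model config):

```python
import torch

# Illustrative sizes only.
bsz, q_len, num_heads, head_dim = 1, 16, 4, 64

hidden = torch.randn(bsz, q_len, num_heads, head_dim)

# Packed layout: all tokens of all sequences concatenated along dim 0.
packed = hidden.squeeze(0).view(-1, num_heads, head_dim)   # [bsz * q_len, heads, dim]

# Per-sequence metadata, as built in the diff for the bsz = 1 case.
b_start_loc = torch.zeros(bsz, dtype=torch.int64)          # sequence i starts at token b_start_loc[i]
b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64)   # sequence i has b_seq_len[i] tokens
max_input_len = q_len

assert packed.shape[0] == int(b_seq_len.sum())
```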
ktransformers/operators/dynamic_attention.py

@@ -17,7 +17,10 @@ import logging
 logger = logging.getLogger("dynamic_attention")
 sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
 from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+try:
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
+except:
+    print("falsh attn not found")
 import math
ktransformers/operators/linear.py

@@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 import cpuinfer_ext
 from ktransformers.operators.cpuinfer import CPUInfer
 from ktransformers.server.config.config import Config
+from typing import Dict, Tuple, Optional, Union
+import numpy as np

 #class KLinearBase(BaseInjectedModule, ABC):
 class KLinearBase(ABC):

@@ -176,16 +178,182 @@ class KLinearTorch(KLinearBase):
         if self.has_bias:
             self.bias = None

+class KLinearQ8(KLinearBase):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "cuda",
+        **kwargs,
+    ):
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.has_bias = False
+        self.compute_dtype = torch.float32
+        self.weight = None
+        self.weight_scale = None
+        self.weight_zero_point = None
+        self.bias = None
+        self.loaded = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_dtype = x.dtype
+        out_device = x.device
+        x = x.to(device=self.device, dtype=self.compute_dtype)
+        # Do the matmul with the original weights to mimic the original behaviour
+        # Dequantize the weights for the matrix multiplication
+        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
+        out = x @ weight_dequant.T
+        if self.has_bias:
+            out = out + self.bias
+        return out.to(dtype=orig_dtype, device=out_device)
+
+    def _dequantize_weight(self, q_matrix, scales, bits=8):
+        """
+        Dequantize a low-precision matrix back to floating-point
+
+        Args:
+            q_matrix (torch.Tensor): Quantized int matrix
+            scales (torch.Tensor): Scale factors for each column
+            bits (int): Quantization bits used (8 or 4)
+
+        Returns:
+            torch.Tensor: Dequantized floating-point matrix
+        """
+        # Ensure inputs are torch tensors
+        if not isinstance(q_matrix, torch.Tensor):
+            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
+        if not isinstance(scales, torch.Tensor):
+            scales = torch.tensor(scales, dtype=torch.float32)
+
+        # Convert to correct dtype if needed
+        if q_matrix.dtype != torch.int8:
+            q_matrix = q_matrix.to(torch.int8)
+        if scales.dtype != torch.float32:
+            scales = scales.to(torch.float32)
+
+        # For Q4, ensure the values stay within 4-bit range
+        if bits == 4:
+            q_matrix = torch.clamp(q_matrix, -7, 7)
+
+        rows, cols = q_matrix.shape
+        dequant_matrix = q_matrix.to(torch.float32)
+        scales_broadcast = scales.view(1, cols)
+
+        # Apply dequantization to all columns at once using matrix multiplication
+        dequant_matrix = dequant_matrix * scales_broadcast
+
+        return dequant_matrix
+
+    def _quantize_weight(self, matrix, bits=8):
+        """
+        Quantize a floating-point matrix to lower precision (Q8 or Q4)
+
+        Args:
+            matrix (torch.Tensor): Input matrix in floating-point format
+            bits (int): Quantization bits, either 8 or 4
+
+        Returns:
+            tuple: (quantized int matrix, scale factors for each column)
+        """
+        if not isinstance(matrix, torch.Tensor):
+            matrix = torch.tensor(matrix, dtype=torch.float32)
+
+        # Convert to float32 if needed
+        if matrix.dtype != torch.float32:
+            matrix = matrix.to(torch.float32)
+
+        # Get matrix shape
+        rows, cols = matrix.shape
+
+        # Determine quantization parameters based on bits
+        if bits == 8:
+            max_int = 127
+            qtype = torch.int8
+        elif bits == 4:
+            max_int = 7
+            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range, wait for native support
+        else:
+            raise ValueError("Quantization bits must be either 8 or 4")
+
+        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
+
+        # Calculate max absolute value for each column
+        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)
+
+        # Handle zero columns (avoid division by zero)
+        zero_cols = max_abs_vals == 0
+        max_abs_vals[zero_cols] = 1.0
+
+        # Calculate scale factors for all columns at once
+        scales = max_abs_vals / max_int
+
+        # Prepare the scales for broadcasting [1, cols]
+        scales_broadcast = scales.view(1, cols)
+
+        # Apply quantization to the entire matrix at once
+        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)
+
+        # For Q4, clamp values to ensure they stay within 4-bit range
+        if bits == 4:
+            q_matrix = torch.clamp(q_matrix, -max_int, max_int)
+
+        return q_matrix, scales
+
+    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
+        if self.loaded:
+            return
+        if device is None:
+            device = self.device
+        if w is None:
+            w = self.load_weight(device=device)
+
+        if isinstance(w, nn.Parameter):
+            try:
+                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
+            except:
+                weight = w.to(dtype=self.compute_dtype)
+            self.has_bias = False
+        elif isinstance(w, tuple):
+            try:
+                weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
+            except:
+                weight = w[0].to(dtype=self.compute_dtype)
+            self.bias = w[1].to(dtype=self.compute_dtype).to(device)
+            self.has_bias = True
+        else:
+            raise ValueError("Invalid weight type")
+
+        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
+        self.weight = self.weight.to(device)
+        self.weight_scale = self.weight_scale.to(device)
+        if self.has_bias:
+            self.bias = self.bias.to(device)
+        self.loaded = True
+
+    def unload(self):
+        self.weight = None
+        self.weight_scale = None
+        self.weight_zero_point = None
+        self._orig_weight = None
+        if self.has_bias:
+            self.bias = None
+        self.loaded = False
+
 class KLinearFP8(KLinearBase):
     # this kernel requires special handling for weight
     # Please load the weight file downloaded from KVCache.AI
     marlin_q_w: torch.Tensor
     marlin_s: torch.Tensor
     g_idx: torch.Tensor
     sort_indices: torch.Tensor
     has_bias: bool
     weight: torch.Tensor
     scale_w: torch.Tensor
     bias: torch.Tensor
     def __init__(
         self,

@@ -468,6 +636,7 @@ LINEAR_MAP = {
     "KLinearTorch": KLinearTorch,
     "KLinearCPUInfer": KLinearCPUInfer,
     "KLinearFP8": KLinearFP8,
+    "KLinearQ8": KLinearQ8,
 }

 class KTransformersLinear(BaseInjectedModule, KLinearBase):
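KLinearQ8 stores int8 weights plus one float32 scale per column (scale = max|w| / 127, q = round(w / scale)) and dequantizes on the fly before computing x @ W.T. The standalone sketch below reproduces just that round trip outside the class so the numerics are easy to poke at; the helper names are mine, not part of the module:

```python
import torch

def quantize_q8(matrix: torch.Tensor):
    # Per-column symmetric int8 quantization, mirroring KLinearQ8._quantize_weight (bits=8).
    max_abs, _ = torch.max(torch.abs(matrix), dim=0)
    max_abs[max_abs == 0] = 1.0              # avoid division by zero for all-zero columns
    scales = max_abs / 127
    q = torch.round(matrix / scales.view(1, -1)).to(torch.int8)
    return q, scales

def dequantize_q8(q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Mirrors KLinearQ8._dequantize_weight for bits=8.
    return q.to(torch.float32) * scales.view(1, -1)

# Weight laid out as [out_features, in_features], like the module expects.
w = torch.randn(256, 512)
q, s = quantize_q8(w)
w_hat = dequantize_q8(q, s)

x = torch.randn(4, 512)
ref = x @ w.T
approx = x @ w_hat.T
print("max abs weight error:", (w - w_hat).abs().max().item())
print("mean abs output error:", (ref - approx).abs().mean().item())
```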
ktransformers/operators/models.py

@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
     DeepseekV2DecoderLayer,
     DeepseekV2MoE,
 )
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule

@@ -649,8 +650,8 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability() < 8:
-                print("for Windows or GPU before ampere, use forward_windows")
+            if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+                # print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
                     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions

@@ -673,6 +674,7 @@ class KDeepseekV2Model(BaseInjectedModule):
         t_f = 0

         for i, decoder_layer in enumerate(self.layers):
+            # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
             if self.transfer_map is not None and i in self.transfer_map:
                 prev_stream = torch.cuda.current_stream()
                 cur_device = self.transfer_map[i]
ktransformers/operators/triton_attention.py

@@ -6,7 +6,7 @@
 import triton
 import triton.language as tl
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid

@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
     # [TODO] work around shmem limit on MI3xx
-    # TODO: support hip
-    # if is_hip_ and Lk >= 576:
-    #     BLOCK = 16
+    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
+        BLOCK = 16

     if Lk == 576:
         BLOCK_DMODEL = 512
ktransformers/operators/triton_attention_prefill.py  (new file, 0 → 100644)

# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
"""
Memory-efficient attention for prefill.
It supporst page size = 1.
"""

import torch
import triton
import triton.language as tl

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    end_n = (
        cur_batch_seq_len
        if not IS_CAUSAL
        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    )
    for start_n in range(0, block_mask * end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(
                (start_n + offs_n[None, :] < cur_batch_seq_len)
                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
                0,
                float("-inf"),
            )
        else:
            qk += tl.where(
                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
            )

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]))


def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=BLOCK,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
    )
\ No newline at end of file
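As a readability aid, here is a plain-PyTorch sketch of the semantics this kernel implements over the packed layout (each sequence attends causally within itself). It assumes the same number of query and key/value heads (the kernel itself also supports grouped KV heads via kv_group_num); the function name is illustrative and this is a shape/semantics check, not a performance path:

```python
import torch

def context_attention_reference(q, k, v, b_start_loc, b_seq_len, is_causal=True):
    """q, k, v: [total_tokens, heads, head_dim]; returns the same layout as `o`."""
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[-1], dtype=torch.float32)
    sm_scale = 1.0 / (q.shape[-1] ** 0.5)
    for start, length in zip(b_start_loc.tolist(), b_seq_len.tolist()):
        qs = q[start:start + length].transpose(0, 1).float()   # [heads, s, d]
        ks = k[start:start + length].transpose(0, 1).float()
        vs = v[start:start + length].transpose(0, 1).float()
        scores = (qs @ ks.transpose(-1, -2)) * sm_scale        # [heads, s, s]
        if is_causal:
            mask = torch.ones(length, length, dtype=torch.bool).tril()
            scores = scores.masked_fill(~mask, float("-inf"))
        out[start:start + length] = (scores.softmax(-1) @ vs).transpose(0, 1)
    return out

# Tiny usage example with two packed sequences of lengths 3 and 5.
lens = torch.tensor([3, 5])
starts = torch.tensor([0, 3])
q = torch.randn(8, 2, 64); k = torch.randn(8, 2, 64); v = torch.randn(8, 2, 64)
o = context_attention_reference(q, k, v, starts, lens)
print(o.shape)  # torch.Size([8, 2, 64])
```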
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml

@@ -22,7 +22,7 @@
   replace:
     class: ktransformers.operators.linear.KTransformersLinear
     kwargs:
-      generate_device: "cuda"
+      generate_device: "cpu"
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml  (new file, 0 → 100644)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearCPUInfer"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cpu"
      prefill_device: "cuda"
      generate_op: "KLinearQ8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
\ No newline at end of file
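The match rules in these files select modules by regular expressions over their dotted names; the KLinearQ8 rule above, for instance, uses a negative lookahead so the kv_b_proj projections keep their original implementation. A quick way to sanity-check a pattern against candidate names (pure stdlib; the module names below are illustrative only, and the doubled backslashes in the YAML are YAML escaping for the same pattern):

```python
import re

# Pattern from the KLinearQ8 rule above, written with single escaping in Python.
pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")

candidates = [
    "model.layers.0.mlp.gate_proj",        # matched -> replaced with KTransformersLinear / KLinearQ8
    "model.layers.3.self_attn.q_a_proj",   # matched
    "model.layers.3.self_attn.kv_b_proj",  # excluded by the negative lookahead
    "lm_head",                             # handled by its own rule
]
for name in candidates:
    print(f"{name}: {'match' if pattern.match(name) else 'no match'}")
```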
ktransformers/tests/test_pytorch_q8.py  (new file, 0 → 100644)

import torch

# Define a floating-point model containing a linear layer
class LinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

# Create the floating-point model instance
in_features = 64
out_features = 128
model_fp32 = LinearModel(in_features, out_features)

# Create the quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,          # the original floating-point model
    {torch.nn.Linear},   # the set of layer types to quantize
    dtype=torch.qint8    # the target quantized data type
)

# Test the model
batch_size = 32
input_fp32 = torch.randn(1, batch_size, in_features)  # generate random input data
output_int8 = model_int8(input_fp32)  # run the data through the quantized model

# Print the output shape for verification
print(f"Input shape: {input_fp32.shape}")
print(f"Output shape: {output_int8.shape}")

# Compare the outputs of the original and quantized models
with torch.no_grad():
    output_fp32 = model_fp32(input_fp32)
    print(f"First few values of the FP32 output: {output_fp32[0, :5]}")
    print(f"First few values of the INT8 output: {output_int8[0, :5]}")

    # Compute the mean error
    error = torch.abs(output_fp32 - output_int8).mean().item()
    print(f"Mean absolute error: {error}")

# Print model type information
print(f"Model type before quantization: {type(model_fp32.linear)}")
print(f"Model type after quantization: {type(model_int8.linear)}")
\ No newline at end of file
ktransformers/util/vendors.py  (new file, 0 → 100644)

from __future__ import annotations

from enum import IntEnum, auto
from typing import Optional, Union, List

import torch

class GPUVendor(IntEnum):
    NVIDIA = auto()
    AMD = auto()
    MooreThreads = auto()
    MetaX = auto()
    MUSA = auto()
    Unknown = auto()

class DeviceManager:
    """
    Device manager that provides a unified interface for handling different GPU vendors
    """
    def __init__(self):
        self.gpu_vendor = self._detect_gpu_vendor()
        self.available_devices = self._get_available_devices()

    def _detect_gpu_vendor(self) -> GPUVendor:
        """Detect GPU vendor type"""
        if not torch.cuda.is_available():
            # Check MUSA availability (assuming a musa module exists)
            try:
                import musa
                if musa.is_available():
                    return GPUVendor.MUSA
            except (ImportError, AttributeError):
                pass
            return GPUVendor.Unknown

        device_name = torch.cuda.get_device_name(0).lower()

        if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
            return GPUVendor.NVIDIA
        elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
            return GPUVendor.AMD
        elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
            return GPUVendor.MooreThreads
        elif any(name in device_name for name in ["metax", "meta"]):
            return GPUVendor.MetaX
        elif "musa" in device_name:
            return GPUVendor.MUSA

        # Backend check
        try:
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                return GPUVendor.AMD
            elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
                return GPUVendor.NVIDIA
        except:
            pass

        return GPUVendor.Unknown

    def _get_available_devices(self) -> List[int]:
        """Get list of available device indices"""
        devices = []
        if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
            devices = list(range(torch.cuda.device_count()))
        elif self.gpu_vendor == GPUVendor.MUSA:
            try:
                import musa
                devices = list(range(musa.device_count()))
            except (ImportError, AttributeError):
                pass
        return devices

    def get_device_str(self, device_id: Union[int, str]) -> str:
        """
        Get device string for the given device ID

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Device string representation (e.g., "cuda:0", "musa:1", "cpu")
        """
        if device_id == -1 or device_id == "cpu":
            return "cpu"

        if isinstance(device_id, int):
            if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
                if device_id < torch.cuda.device_count():
                    return f"cuda:{device_id}"
            elif self.gpu_vendor == GPUVendor.MUSA:
                try:
                    import musa
                    if device_id < musa.device_count():
                        return f"musa:{device_id}"
                except (ImportError, AttributeError):
                    pass

        return "cpu"

    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
        """
        Convert device ID to torch.device object

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            torch.device object
        """
        device_str = self.get_device_str(device_id)

        # Handle MUSA device
        if device_str.startswith("musa:"):
            try:
                import musa
                index = int(device_str.split(":")[-1])
                return musa.device(index)
            except (ImportError, ValueError, AttributeError):
                return torch.device("cpu")

        # Standard PyTorch device
        return torch.device(device_str)

    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
        """
        Move tensor to specified device

        Args:
            tensor: PyTorch tensor to move
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Tensor moved to the specified device
        """
        device = self.to_torch_device(device_id)
        return tensor.to(device)

    def is_available(self, index: int = 0) -> bool:
        """
        Check if device at specified index is available

        Args:
            index: Device index to check

        Returns:
            True if the device is available, False otherwise
        """
        if index < 0:
            return True  # CPU is always available
        return index in self.available_devices

    def get_all_devices(self) -> List[int]:
        """
        Get all available device indices

        Returns:
            List of available device indices (0, 1, 2, etc.)
        """
        return self.available_devices

# Create global device manager instance
device_manager = DeviceManager()

# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
    """
    Get torch.device object for the specified device ID

    Args:
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        torch.device object
    """
    return device_manager.to_torch_device(device_id)

def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
    """
    Move tensor to specified device

    Args:
        tensor: PyTorch tensor to move
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        Tensor moved to the specified device
    """
    return device_manager.move_tensor_to_device(tensor, device_id)

# Get devices
cpu_device = get_device(-1)      # CPU using index -1
cpu_device2 = get_device("cpu")  # CPU using string "cpu"
gpu0 = get_device(0)             # First GPU

# Move tensors
x = torch.randn(3, 3)
x_gpu = to_device(x, 0)       # Move to first GPU
x_cpu1 = to_device(x, -1)     # Move to CPU using index -1
x_cpu2 = to_device(x, "cpu")  # Move to CPU using string "cpu"
\ No newline at end of file
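Elsewhere in this commit (attention.py, models.py) the global device_manager is used to gate the flash-attention fast path to NVIDIA GPUs and to fall back to the masked / Triton path otherwise. A condensed sketch of that dispatch pattern; pick_attention_path is a made-up helper for illustration, and the real code calls forward_windows / _update_causal_mask rather than returning a string:

```python
import os
from ktransformers.util.vendors import device_manager, GPUVendor

def pick_attention_path(compute_capability: int) -> str:
    # Mirrors the gating condition introduced in this commit: Windows, pre-Ampere
    # GPUs, and non-NVIDIA vendors (e.g. AMD/ROCm) all take the fallback path.
    if os.name == "nt" or compute_capability < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
        return "fallback"
    return "flash_attn"

print(pick_attention_path(compute_capability=9))
```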
setup.py

@@ -29,7 +29,7 @@ import torch.version
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 from setuptools import setup, Extension
 from cpufeature.extension import CPUFeature
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 try:
     from torch_musa.utils.simple_porting import SimplePorting
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME

@@ -64,6 +64,70 @@ class VersionInfo:
             musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
         return musa_version

+    def get_rocm_bare_metal_version(self, rocm_dir):
+        """
+        Get the ROCm version from the ROCm installation directory.
+
+        Args:
+            rocm_dir: Path to the ROCm installation directory
+
+        Returns:
+            A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
+        """
+        try:
+            # Try using rocm_agent_enumerator to get version info
+            raw_output = subprocess.check_output(
+                [rocm_dir + "/bin/rocminfo", "--version"],
+                universal_newlines=True,
+                stderr=subprocess.STDOUT
+            )
+            # Extract version number from output
+            match = re.search(r'(\d+\.\d+)', raw_output)
+            if match:
+                version_str = match.group(1)
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            # If rocminfo --version fails, try alternative methods
+            pass
+
+        try:
+            # Try reading version from release file
+            with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
+                version_str = f.read().strip()
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (FileNotFoundError, IOError):
+            pass
+
+        # If all else fails, try to extract from directory name
+        dir_name = os.path.basename(os.path.normpath(rocm_dir))
+        match = re.search(r'rocm-(\d+\.\d+)', dir_name)
+        if match:
+            version_str = match.group(1)
+            version = parse(version_str)
+            rocm_version = f"{version.major}{version.minor}"
+            return rocm_version
+
+        # Fallback to extracting from hipcc version
+        try:
+            raw_output = subprocess.check_output(
+                [rocm_dir + "/bin/hipcc", "--version"],
+                universal_newlines=True,
+                stderr=subprocess.STDOUT
+            )
+            match = re.search(r'HIP version: (\d+\.\d+)', raw_output)
+            if match:
+                version_str = match.group(1)
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            pass
+
+        # If we still can't determine the version, raise an error
+        raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
+
     def get_cuda_bare_metal_version(self, cuda_dir):
         raw_output = subprocess.check_output(
             [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)

@@ -148,11 +212,13 @@ class VersionInfo:
         cpu_instruct = self.get_cpu_instruct()
         backend_version = ""
         if CUDA_HOME is not None:
-            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+            backend_version = f""
         elif MUSA_HOME is not None:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+        elif ROCM_HOME is not None:
+            backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
         else:
-            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+            raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
         package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
         if full_version:
             return package_version

@@ -247,9 +313,13 @@ class CMakeBuild(BuildExtension):
             cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
         elif MUSA_HOME is not None:
             cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+        elif ROCM_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
         else:
             raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+
+        # log cmake_args
+        print("CMake args:", cmake_args)

         build_args = []
         if "CMAKE_ARGS" in os.environ:
             cmake_args += [

@@ -328,7 +398,7 @@ class CMakeBuild(BuildExtension):
             ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
         )

-if CUDA_HOME is not None:
+if CUDA_HOME is not None or ROCM_HOME is not None:
     ops_module = CUDAExtension('KTransformersOps', [
         'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
         'ktransformers/ktransformers_ext/cuda/binding.cpp',

@@ -338,7 +408,7 @@ if CUDA_HOME is not None:
             'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
             'nvcc': [
                 '-O3',
-                '--use_fast_math',
+                # '--use_fast_math',
                 '-Xcompiler', '-fPIC',
                 '-DKTRANSFORMERS_USE_CUDA',
             ]

@@ -371,6 +441,7 @@ else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

 setup(
     name=VersionInfo.PACKAGE_NAME,
     version=VersionInfo().get_package_version(),
     cmdclass={"bdist_wheel": BuildWheelsCommand, "build_ext": CMakeBuild},
     ext_modules=[
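The new get_rocm_bare_metal_version feeds the wheel's local version tag the same way the CUDA/MUSA probes do, so ROCm 6.3 ends up as "rocm63" in backend_version. A tiny sketch of that formatting step using packaging's parse (already used in setup.py), with a hard-coded version string standing in for the rocminfo/hipcc probing:

```python
from packaging.version import parse

def rocm_tag(version_str: str) -> str:
    # "6.3.1" -> "rocm63", mirroring backend_version = f"rocm{major}{minor}" in setup.py.
    version = parse(version_str)
    return f"rocm{version.major}{version.minor}"

print(rocm_tag("6.3.1"))  # rocm63
print(rocm_tag("6.2"))    # rocm62
```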