ox696c / ktransformers · Commits · ed843741

Commit ed843741, authored Mar 14, 2025 by Azure-Tang
    merge main; Add torch q8 linear
Parent: 6c4ed591

Showing 7 changed files with 540 additions and 15 deletions (+540 / -15)
ktransformers/operators/triton_attention.py                       +3    -3
ktransformers/operators/triton_attention_prefill.py               +206  -0
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml  +3    -3
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml       +3    -3
ktransformers/tests/test_pytorch_q8.py                            +46   -0
ktransformers/util/vendors.py                                     +202  -0
setup.py                                                          +77   -6
ktransformers/operators/triton_attention.py (+3 / -3)

```diff
@@ -6,7 +6,7 @@
 import triton
 import triton.language as tl
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
     # [TODO] work around shmem limit on MI3xx
     # TODO: support hip
-    # if is_hip_ and Lk >= 576:
-    #     BLOCK = 16
+    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
+        BLOCK = 16
     if Lk == 576:
         BLOCK_DMODEL = 512
```
ktransformers/operators/triton_attention_prefill.py (new file, +206)

```python
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
"""
Memory-efficient attention for prefill.
It supports page size = 1.
"""
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1

import torch
import triton
import triton.language as tl

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
    end_n = (
        cur_batch_seq_len
        if not IS_CAUSAL
        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    )
    for start_n in range(0, block_mask * end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(
                (start_n + offs_n[None, :] < cur_batch_seq_len)
                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
                0,
                float("-inf"),
            )
        else:
            qk += tl.where(
                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
            )

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(
        out_ptrs,
        acc,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
    )


def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=BLOCK,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
    )
```
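For orientation, a minimal sketch of how this kernel might be invoked. The shapes, sequence lengths, and dtype below are illustrative assumptions, not part of the commit; it assumes the package is importable and a CUDA-capable GPU is present.

```python
import torch
from ktransformers.operators.triton_attention_prefill import context_attention_fwd

# Two sequences of lengths 5 and 7 packed into one [total_tokens, heads, head_dim]
# buffer, as the docstring describes. Values are made up for illustration.
seq_lens = torch.tensor([5, 7], dtype=torch.int32, device="cuda")
start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")  # prefix sums of seq_lens
total_tokens, heads, kv_heads, head_dim = 12, 8, 4, 64

q = torch.randn(total_tokens, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(total_tokens, kv_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(total_tokens, kv_heads, head_dim, device="cuda", dtype=torch.float16)
o = torch.empty_like(q)

# kv_group_num is derived inside as heads // kv_heads (= 2 here).
context_attention_fwd(q, k, v, o, start_loc, seq_lens, max_input_len=7, is_causal=True)
```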
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml (+3 / -3)

```diff
@@ -13,7 +13,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "KLinearMarlin"
+      generate_op: "KLinearQ8"
       prefill_op: "KLinearTorch"
 - match:
@@ -22,9 +22,9 @@
   replace:
     class: ktransformers.operators.linear.KTransformersLinear
     kwargs:
-      generate_device: "cuda"
+      generate_device: "cpu"
       prefill_device: "cuda"
-      generate_op: "KLinearMarlin"
+      generate_op: "KLinearTorch"
       prefill_op: "KLinearTorch"
 - match:
```
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml (+3 / -3)

```diff
@@ -14,7 +14,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "KLinearMarlin"
+      generate_op: "KLinearQ8"
       prefill_op: "KLinearTorch"
 - match:
@@ -23,9 +23,9 @@
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
-      generate_device: "cuda"
+      generate_device: "cpu"
       prefill_device: "cuda"
-      generate_op: "KLinearMarlin"
+      generate_op: "KLinearCPUInfer"
       prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
```
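Read together, the first rule in each file switches its decode-path linear operator from KLinearMarlin to the new torch q8 backend. A sketch of what the updated rule plausibly looks like in full is shown below; the match expression and indentation are placeholders inferred from the hunk context, not copied from the file.

```yaml
# Sketch only: the match pattern below is a placeholder, not taken from the repository.
- match:
    name: "<linear-layer pattern>"
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearQ8"      # switched from "KLinearMarlin" by this commit
      prefill_op: "KLinearTorch"
```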
ktransformers/tests/test_pytorch_q8.py (new file, +46)

```python
import torch

# Define a float model that contains a linear layer
class LinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

# Create the float model instance
in_features = 64
out_features = 128
model_fp32 = LinearModel(in_features, out_features)

# Create the quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,          # the original float model
    {torch.nn.Linear},   # the set of layer types to quantize
    dtype=torch.qint8    # the target quantized dtype
)

# Test the model
batch_size = 32
input_fp32 = torch.randn(1, batch_size, in_features)  # generate random input data

output_int8 = model_int8(input_fp32)  # run the data through the quantized model

# Print the output shapes for verification
print(f"Input shape: {input_fp32.shape}")
print(f"Output shape: {output_int8.shape}")

# Compare the outputs of the original and the quantized model
with torch.no_grad():
    output_fp32 = model_fp32(input_fp32)

print(f"First FP32 output values: {output_fp32[0, :5]}")
print(f"First INT8 output values: {output_int8[0, :5]}")

# Compute the mean absolute error
error = torch.abs(output_fp32 - output_int8).mean().item()
print(f"Mean absolute error: {error}")

# Print the module types before and after quantization
print(f"Module type before quantization: {type(model_fp32.linear)}")
print(f"Module type after quantization: {type(model_int8.linear)}")
```
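The script above only prints the quantization error. A hedged sketch of how it could be turned into an assertion-based check follows; the tolerance is an arbitrary assumption, not a value taken from the repository.

```python
import torch

def test_dynamic_q8_linear_close_to_fp32(tol: float = 0.1):
    """Dynamic int8 quantization of nn.Linear should stay close to the fp32 output."""
    torch.manual_seed(0)
    model_fp32 = torch.nn.Sequential(torch.nn.Linear(64, 128))
    model_int8 = torch.ao.quantization.quantize_dynamic(
        model_fp32, {torch.nn.Linear}, dtype=torch.qint8
    )
    x = torch.randn(1, 32, 64)
    with torch.no_grad():
        err = (model_fp32(x) - model_int8(x)).abs().mean().item()
    # tol is an illustrative threshold; pick one appropriate for the weight scale.
    assert err < tol, f"mean abs error {err} exceeds tolerance {tol}"
```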
ktransformers/util/vendors.py (new file, +202)

```python
from __future__ import annotations

from enum import IntEnum, auto
from typing import Optional, Union, List

import torch


class GPUVendor(IntEnum):
    NVIDIA = auto()
    AMD = auto()
    MooreThreads = auto()
    MetaX = auto()
    MUSA = auto()
    Unknown = auto()


class DeviceManager:
    """
    Device manager that provides a unified interface for handling different GPU vendors
    """
    def __init__(self):
        self.gpu_vendor = self._detect_gpu_vendor()
        self.available_devices = self._get_available_devices()

    def _detect_gpu_vendor(self) -> GPUVendor:
        """Detect GPU vendor type"""
        if not torch.cuda.is_available():
            # Check MUSA availability (assuming a musa module exists)
            try:
                import musa
                if musa.is_available():
                    return GPUVendor.MUSA
            except (ImportError, AttributeError):
                pass
            return GPUVendor.Unknown

        device_name = torch.cuda.get_device_name(0).lower()

        if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
            return GPUVendor.NVIDIA
        elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
            return GPUVendor.AMD
        elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
            return GPUVendor.MooreThreads
        elif any(name in device_name for name in ["metax", "meta"]):
            return GPUVendor.MetaX
        elif "musa" in device_name:
            return GPUVendor.MUSA

        # Backend check
        try:
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                return GPUVendor.AMD
            elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
                return GPUVendor.NVIDIA
        except:
            pass

        return GPUVendor.Unknown

    def _get_available_devices(self) -> List[int]:
        """Get list of available device indices"""
        devices = []

        if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
            devices = list(range(torch.cuda.device_count()))
        elif self.gpu_vendor == GPUVendor.MUSA:
            try:
                import musa
                devices = list(range(musa.device_count()))
            except (ImportError, AttributeError):
                pass

        return devices

    def get_device_str(self, device_id: Union[int, str]) -> str:
        """
        Get device string for the given device ID

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Device string representation (e.g., "cuda:0", "musa:1", "cpu")
        """
        if device_id == -1 or device_id == "cpu":
            return "cpu"

        if isinstance(device_id, int):
            if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
                if device_id < torch.cuda.device_count():
                    return f"cuda:{device_id}"
            elif self.gpu_vendor == GPUVendor.MUSA:
                try:
                    import musa
                    if device_id < musa.device_count():
                        return f"musa:{device_id}"
                except (ImportError, AttributeError):
                    pass

        return "cpu"

    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
        """
        Convert device ID to torch.device object

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            torch.device object
        """
        device_str = self.get_device_str(device_id)

        # Handle MUSA device
        if device_str.startswith("musa:"):
            try:
                import musa
                index = int(device_str.split(":")[-1])
                return musa.device(index)
            except (ImportError, ValueError, AttributeError):
                return torch.device("cpu")

        # Standard PyTorch device
        return torch.device(device_str)

    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
        """
        Move tensor to specified device

        Args:
            tensor: PyTorch tensor to move
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Tensor moved to the specified device
        """
        device = self.to_torch_device(device_id)
        return tensor.to(device)

    def is_available(self, index: int = 0) -> bool:
        """
        Check if device at specified index is available

        Args:
            index: Device index to check

        Returns:
            True if the device is available, False otherwise
        """
        if index < 0:
            return True  # CPU is always available
        return index in self.available_devices

    def get_all_devices(self) -> List[int]:
        """
        Get all available device indices

        Returns:
            List of available device indices (0, 1, 2, etc.)
        """
        return self.available_devices


# Create global device manager instance
device_manager = DeviceManager()


# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
    """
    Get torch.device object for the specified device ID

    Args:
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        torch.device object
    """
    return device_manager.to_torch_device(device_id)


def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
    """
    Move tensor to specified device

    Args:
        tensor: PyTorch tensor to move
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        Tensor moved to the specified device
    """
    return device_manager.move_tensor_to_device(tensor, device_id)


# Get devices
cpu_device = get_device(-1)       # CPU using index -1
cpu_device2 = get_device("cpu")   # CPU using string "cpu"
gpu0 = get_device(0)              # First GPU

# Move tensors
x = torch.randn(3, 3)
x_gpu = to_device(x, 0)       # Move to first GPU
x_cpu1 = to_device(x, -1)     # Move to CPU using index -1
x_cpu2 = to_device(x, "cpu")  # Move to CPU using string "cpu"
```
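A brief sketch of how this module is consumed elsewhere in the commit: triton_attention.py now branches on the detected vendor instead of a HIP flag. The snippet below mirrors that change; the Lk value and the fallback block size are illustrative, and the surrounding kernel-setup code is omitted.

```python
import torch
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

# Vendor-specific branching, mirroring the change in triton_attention.py.
Lk = 576  # illustrative head-dim value
if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
    BLOCK = 16  # work around the shared-memory limit on MI3xx GPUs
else:
    BLOCK = 64  # illustrative default

# Unified device handling regardless of vendor.
t = torch.randn(4, 4)
t = to_device(t, 0)  # first GPU if one is detected, otherwise falls back to CPU
print(get_device("cpu"), BLOCK)
```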
setup.py (+77 / -6)

```diff
@@ -29,7 +29,7 @@ import torch.version
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 from setuptools import setup, Extension
 from cpufeature.extension import CPUFeature
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 try:
     from torch_musa.utils.simple_porting import SimplePorting
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
@@ -64,6 +64,70 @@ class VersionInfo:
         musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
         return musa_version

+    def get_rocm_bare_metal_version(self, rocm_dir):
+        """
+        Get the ROCm version from the ROCm installation directory.
+
+        Args:
+            rocm_dir: Path to the ROCm installation directory
+
+        Returns:
+            A string representation of the ROCm version (e.g., "63" for ROCm 6.3)
+        """
+        try:
+            # Try using rocm_agent_enumerator to get version info
+            raw_output = subprocess.check_output(
+                [rocm_dir + "/bin/rocminfo", "--version"],
+                universal_newlines=True, stderr=subprocess.STDOUT)
+            # Extract version number from output
+            match = re.search(r'(\d+\.\d+)', raw_output)
+            if match:
+                version_str = match.group(1)
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            # If rocminfo --version fails, try alternative methods
+            pass
+
+        try:
+            # Try reading version from release file
+            with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f:
+                version_str = f.read().strip()
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (FileNotFoundError, IOError):
+            pass
+
+        # If all else fails, try to extract from directory name
+        dir_name = os.path.basename(os.path.normpath(rocm_dir))
+        match = re.search(r'rocm-(\d+\.\d+)', dir_name)
+        if match:
+            version_str = match.group(1)
+            version = parse(version_str)
+            rocm_version = f"{version.major}{version.minor}"
+            return rocm_version
+
+        # Fallback to extracting from hipcc version
+        try:
+            raw_output = subprocess.check_output(
+                [rocm_dir + "/bin/hipcc", "--version"],
+                universal_newlines=True, stderr=subprocess.STDOUT)
+            match = re.search(r'HIP version: (\d+\.\d+)', raw_output)
+            if match:
+                version_str = match.group(1)
+                version = parse(version_str)
+                rocm_version = f"{version.major}{version.minor}"
+                return rocm_version
+        except (subprocess.CalledProcessError, FileNotFoundError):
+            pass
+
+        # If we still can't determine the version, raise an error
+        raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}")
+
     def get_cuda_bare_metal_version(self, cuda_dir):
         raw_output = subprocess.check_output(
             [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -148,11 +212,13 @@ class VersionInfo:
         cpu_instruct = self.get_cpu_instruct()
         backend_version = ""
         if CUDA_HOME is not None:
-            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+            backend_version = f""
         elif MUSA_HOME is not None:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+        elif ROCM_HOME is not None:
+            backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
         else:
-            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+            raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
         package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
         if full_version:
             return package_version
@@ -247,8 +313,12 @@ class CMakeBuild(BuildExtension):
             cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
         elif MUSA_HOME is not None:
             cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+        elif ROCM_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
         else:
             raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+        # log cmake_args
+        print("CMake args:", cmake_args)
         build_args = []
         if "CMAKE_ARGS" in os.environ:
@@ -328,7 +398,7 @@ class CMakeBuild(BuildExtension):
             ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
         )

-if CUDA_HOME is not None:
+if CUDA_HOME is not None or ROCM_HOME is not None:
     ops_module = CUDAExtension('KTransformersOps', [
         'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
         'ktransformers/ktransformers_ext/cuda/binding.cpp',
@@ -338,7 +408,7 @@ if CUDA_HOME is not None:
         'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
         'nvcc': [
             '-O3',
-            '--use_fast_math',
+            # '--use_fast_math',
             '-Xcompiler', '-fPIC',
             '-DKTRANSFORMERS_USE_CUDA',
         ]
@@ -371,6 +441,7 @@ else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

 setup(
+    name=VersionInfo.PACKAGE_NAME,
     version=VersionInfo().get_package_version(),
     cmdclass={"bdist_wheel": BuildWheelsCommand, "build_ext": CMakeBuild},
     ext_modules=[
```
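For orientation, a rough sketch of what the version-tagging logic above would produce on a hypothetical ROCm 6.3 build; every concrete value below is invented for illustration and does not come from the commit.

```python
# Hypothetical inputs -- none of these values are taken from the repository.
flash_version = "0.2.2"      # placeholder package version
backend_version = "rocm63"   # as returned by get_rocm_bare_metal_version(ROCM_HOME)
torch_version = "2.6.0"      # placeholder torch version string
cpu_instruct = "avx2"        # placeholder CPU-instruction suffix

package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
print(package_version)       # -> 0.2.2+rocm63torch2.6.0avx2
```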