Commit 086a9d1c in OpenDAS/ktransformers

Add vendor control

Authored Mar 13, 2025 by Azure-Tang
Parent: c009512a

Showing 8 changed files with 446 additions and 17 deletions (+446, -17)
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu        +1    -0
ktransformers/local_chat.py                                        +2    -1
ktransformers/operators/attention.py                               +26   -11
ktransformers/operators/dynamic_attention.py                       +4    -1
ktransformers/operators/models.py                                  +2    -1
ktransformers/operators/triton_attention.py                        +3    -3
ktransformers/operators/triton_attention_prefill.py (new file)     +206  -0
ktransformers/util/vendors.py (new file)                           +202  -0
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu (+1, -0)

@@ -15,6 +15,7 @@
 #include <torch/torch.h>
 #include <cstdint>
 #include <c10/cuda/CUDAGuard.h>
+typedef hip_bfloat16 nv_bfloat16;

 __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
ktransformers/local_chat.py (+2, -1)

@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,

@@ -169,7 +170,7 @@ def local_chat(
     assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
     "please change max_seq_len in ~/.ktransformers/config.yaml"
-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
ktransformers/operators/attention.py (+26, -11)

@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+try:
+    from flash_attn import flash_attn_func
+except:
+    pass
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:

@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

-        attn_output = flash_attn_func(
-            query_states,
-            key_states,
-            value_states_padded,
-            softmax_scale=self.softmax_scale,
-            causal=True,
-        )
+        # for bsz = 1
+        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
+        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
+        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
+        max_input_len = q_len
+        context_attention_fwd(q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+                              k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+                              v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
+                              o=attn_output,
+                              b_start_loc=b_start_loc,
+                              b_seq_len=b_seq_len,
+                              max_input_len=max_input_len,
+                              is_causal=True)

         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, :, : self.v_head_dim]
+            attn_output = attn_output[:, :, : self.v_head_dim]

         attn_output = attn_output.reshape(
             bsz, q_len, self.num_heads * self.v_head_dim

@@ -589,7 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability() < 8:
+        if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
            print("for Windows or GPU before ampere, use forward_windows")
            return self.forward_windows(
                hidden_states,
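Note on the @@ -319,18 +325,27 @@ hunk above: in MLA, query/key heads carry the concatenated no-PE + RoPE width (q_head_dim) while value heads are narrower (v_head_dim). The old flash_attn_func path therefore zero-pads values up to the query width and slices the attention output back down afterwards, whereas the new context_attention_fwd path allocates its output at v_head_dim directly. A minimal, CPU-runnable sketch of just that tensor bookkeeping (head sizes are illustrative, not taken from this diff):

import torch

# Illustrative MLA-style head sizes (not from this commit).
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
q_head_dim = qk_nope_head_dim + qk_rope_head_dim   # 192
bsz, q_len, num_heads = 1, 4, 2

query_states = torch.randn(bsz, q_len, num_heads, q_head_dim)
value_states = torch.randn(bsz, q_len, num_heads, v_head_dim)

# flash_attn_func needs q/k/v with a common head_dim, so v is zero-padded up
# to the query width before the call ...
value_states_padded = torch.nn.functional.pad(
    value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0
)
assert value_states_padded.shape[-1] == q_head_dim

# ... and the attention output (same width as q) is sliced back to the real
# value width, then flattened across heads, as in the hunk above.
attn_output = torch.randn(bsz, q_len, num_heads, q_head_dim)  # stand-in for the kernel output
attn_output = attn_output[:, :, :, :v_head_dim]
attn_output = attn_output.reshape(bsz, q_len, num_heads * v_head_dim)
print(attn_output.shape)  # torch.Size([1, 4, 256])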
ktransformers/operators/dynamic_attention.py (+4, -1)

@@ -17,7 +17,10 @@ import logging
 logger = logging.getLogger("dynamic_attention")
 sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
 from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+try:
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
+except:
+    print("falsh attn not found")
 import math
ktransformers/operators/models.py (+2, -1)

@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
     DeepseekV2DecoderLayer,
     DeepseekV2MoE,
 )
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule

@@ -649,7 +650,7 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability() < 8:
+            if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
                 print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
ktransformers/operators/triton_attention.py (+3, -3)

@@ -6,7 +6,7 @@
 import triton
 import triton.language as tl
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid

@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
     # [TODO] work around shmem limit on MI3xx
     # TODO: support hip
-    # if is_hip_ and Lk >= 576:
-    #     BLOCK = 16
+    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
+        BLOCK = 16

     if Lk == 576:
         BLOCK_DMODEL = 512
ktransformers/operators/triton_attention_prefill.py (new file, 0 → 100644, +206)

# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
"""
Memory-efficient attention for prefill.
It supporst page size = 1.
"""
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    end_n = (
        cur_batch_seq_len
        if not IS_CAUSAL
        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    )
    for start_n in range(0, block_mask * end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(
                (start_n + offs_n[None, :] < cur_batch_seq_len)
                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
                0,
                float("-inf"),
            )
        else:
            qk += tl.where(
                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
            )

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(
        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
    )


def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=BLOCK,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
    )
\ No newline at end of file
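A minimal usage sketch for this new kernel (not part of the commit), assuming a CUDA-capable GPU and an importable ktransformers install: two variable-length sequences are packed along the token dimension in the [b * s, head, head_dim] layout from the docstring, and b_start_loc / b_seq_len record where each sequence starts and how long it is.

import torch
from ktransformers.operators.triton_attention_prefill import context_attention_fwd

if torch.cuda.is_available():
    # Two sequences of length 5 and 3, packed along dim 0 (page size = 1).
    seq_lens = [5, 3]
    total_tokens = sum(seq_lens)          # 8
    num_heads, head_dim = 4, 64

    q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
    o = torch.empty_like(q)               # output buffer, written in place

    # Start offset and length of every sequence inside the packed buffer.
    b_start_loc = torch.tensor([0, 5], dtype=torch.int64, device="cuda")
    b_seq_len = torch.tensor(seq_lens, dtype=torch.int64, device="cuda")

    context_attention_fwd(q, k, v, o,
                          b_start_loc=b_start_loc,
                          b_seq_len=b_seq_len,
                          max_input_len=max(seq_lens),
                          is_causal=True)
    print(o.shape)  # torch.Size([8, 4, 64])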
ktransformers/util/vendors.py (new file, 0 → 100644, +202)

from __future__ import annotations

from enum import IntEnum, auto
from typing import Optional, Union, List

import torch


class GPUVendor(IntEnum):
    NVIDIA = auto()
    AMD = auto()
    MooreThreads = auto()
    MetaX = auto()
    MUSA = auto()
    Unknown = auto()


class DeviceManager:
    """
    Device manager that provides a unified interface for handling different GPU vendors
    """
    def __init__(self):
        self.gpu_vendor = self._detect_gpu_vendor()
        self.available_devices = self._get_available_devices()

    def _detect_gpu_vendor(self) -> GPUVendor:
        """Detect GPU vendor type"""
        if not torch.cuda.is_available():
            # Check MUSA availability (assuming a musa module exists)
            try:
                import musa
                if musa.is_available():
                    return GPUVendor.MUSA
            except (ImportError, AttributeError):
                pass
            return GPUVendor.Unknown

        device_name = torch.cuda.get_device_name(0).lower()

        if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
            return GPUVendor.NVIDIA
        elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
            return GPUVendor.AMD
        elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
            return GPUVendor.MooreThreads
        elif any(name in device_name for name in ["metax", "meta"]):
            return GPUVendor.MetaX
        elif "musa" in device_name:
            return GPUVendor.MUSA

        # Backend check
        try:
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                return GPUVendor.AMD
            elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
                return GPUVendor.NVIDIA
        except:
            pass

        return GPUVendor.Unknown

    def _get_available_devices(self) -> List[int]:
        """Get list of available device indices"""
        devices = []
        if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
            devices = list(range(torch.cuda.device_count()))
        elif self.gpu_vendor == GPUVendor.MUSA:
            try:
                import musa
                devices = list(range(musa.device_count()))
            except (ImportError, AttributeError):
                pass
        return devices

    def get_device_str(self, device_id: Union[int, str]) -> str:
        """
        Get device string for the given device ID

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Device string representation (e.g., "cuda:0", "musa:1", "cpu")
        """
        if device_id == -1 or device_id == "cpu":
            return "cpu"

        if isinstance(device_id, int):
            if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
                if device_id < torch.cuda.device_count():
                    return f"cuda:{device_id}"
            elif self.gpu_vendor == GPUVendor.MUSA:
                try:
                    import musa
                    if device_id < musa.device_count():
                        return f"musa:{device_id}"
                except (ImportError, AttributeError):
                    pass

        return "cpu"

    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
        """
        Convert device ID to torch.device object

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            torch.device object
        """
        device_str = self.get_device_str(device_id)

        # Handle MUSA device
        if device_str.startswith("musa:"):
            try:
                import musa
                index = int(device_str.split(":")[-1])
                return musa.device(index)
            except (ImportError, ValueError, AttributeError):
                return torch.device("cpu")

        # Standard PyTorch device
        return torch.device(device_str)

    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
        """
        Move tensor to specified device

        Args:
            tensor: PyTorch tensor to move
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Tensor moved to the specified device
        """
        device = self.to_torch_device(device_id)
        return tensor.to(device)

    def is_available(self, index: int = 0) -> bool:
        """
        Check if device at specified index is available

        Args:
            index: Device index to check

        Returns:
            True if the device is available, False otherwise
        """
        if index < 0:
            return True  # CPU is always available
        return index in self.available_devices

    def get_all_devices(self) -> List[int]:
        """
        Get all available device indices

        Returns:
            List of available device indices (0, 1, 2, etc.)
        """
        return self.available_devices


# Create global device manager instance
device_manager = DeviceManager()


# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
    """
    Get torch.device object for the specified device ID

    Args:
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        torch.device object
    """
    return device_manager.to_torch_device(device_id)


def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
    """
    Move tensor to specified device

    Args:
        tensor: PyTorch tensor to move
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        Tensor moved to the specified device
    """
    return device_manager.move_tensor_to_device(tensor, device_id)


# Get devices
cpu_device = get_device(-1)        # CPU using index -1
cpu_device2 = get_device("cpu")    # CPU using string "cpu"
gpu0 = get_device(0)               # First GPU

# Move tensors
x = torch.randn(3, 3)
x_gpu = to_device(x, 0)       # Move to first GPU
x_cpu1 = to_device(x, -1)     # Move to CPU using index -1
x_cpu2 = to_device(x, "cpu")  # Move to CPU using string "cpu"
\ No newline at end of file
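Elsewhere in this commit the module is consumed as a guard around NVIDIA-only fast paths (flash-attn / flashinfer MLA), with everything else falling back to the mask-based forward_windows route. A condensed sketch of that gating pattern; the helper name is illustrative, and torch.cuda.get_device_capability stands in here for ktransformers' get_compute_capability:

import os
import torch
from ktransformers.util.vendors import device_manager, GPUVendor

def pick_attention_path() -> str:
    # Windows or a non-NVIDIA (AMD / MooreThreads / MetaX / MUSA / unknown) GPU:
    # take the fallback path, mirroring the checks added in attention.py and models.py.
    if os.name == "nt" or device_manager.gpu_vendor != GPUVendor.NVIDIA:
        return "forward_windows"
    # Pre-Ampere NVIDIA GPUs (compute capability < 8) also fall back.
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        return "forward_windows"
    return "flash-attn / triton / flashinfer"

print(device_manager.gpu_vendor.name, device_manager.get_all_devices())
print(pick_attention_path())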