ox696c / ktransformers / commit 142fb7ce

Enable support for Intel XPU devices, add support for DeepSeek V2/V3 first

Authored May 14, 2025 by rnwang04
Parent: 333351c7
Changes: 22 files in this commit; this page shows 2 changed files with 77 additions and 35 deletions (+77 -35).

  ktransformers/util/utils.py   +44 -14
  setup.py                      +33 -21
ktransformers/util/utils.py  (view file @ 142fb7ce)
@@ -27,7 +27,8 @@ from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
-from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+if not torch.xpu.is_available():
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 import socket

 warm_uped = False
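The flashinfer MLA wrapper is CUDA-only, so its import is now skipped whenever an XPU device is present. A minimal sketch, not part of this diff, of making that check robust on PyTorch builds that predate the torch.xpu module:

    import torch

    def xpu_is_available() -> bool:
        # Hypothetical helper: probe for the torch.xpu module before calling
        # is_available(), so CUDA-only PyTorch builds don't raise AttributeError.
        return hasattr(torch, "xpu") and torch.xpu.is_available()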
@@ -59,6 +60,8 @@ def get_compute_capability(device:torch.device = None):
             return min_compute_capability_major
         else:
             return torch.cuda.get_device_properties(device)
+    else:
+        return 0

 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
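get_compute_capability() now returns 0 instead of failing when no CUDA device is present, so callers can branch uniformly on CPU or XPU machines. A hedged usage sketch (the threshold below is illustrative, not taken from this commit):

    from ktransformers.util.utils import get_compute_capability

    # 0 means "no CUDA device available".
    if get_compute_capability() >= 8:
        print("Ampere-or-newer CUDA GPU detected")
    else:
        print("no suitable CUDA device; using the non-CUDA path")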
@@ -97,7 +100,7 @@ def get_all_used_cuda_device(device_map:dict):
     all_device_list = list(all_device_list)
     return all_device_list

-def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
+def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
     local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
@@ -118,7 +121,10 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.xpu.is_available():
+                torch.xpu.empty_cache()
             weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
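Cache clearing during weight loading is now dispatched on whichever backend is available. The same pattern as a stand-alone helper (a sketch, assuming torch.xpu.empty_cache() mirrors torch.cuda.empty_cache() on XPU-enabled builds, as this diff does):

    import torch

    def empty_device_cache() -> None:
        # Release cached allocator blocks on the active backend; no-op on CPU-only builds.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.empty_cache()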
@@ -126,12 +132,24 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
             #print(load_config.tensor_file_map.keys())
             raise Exception(f"can't find {translated_key} in GGUF file!")

-def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
+def sync_all_device(all_device_list):
+    for device in all_device_list:
+        if "cuda" in device.lower():
+            torch.cuda.synchronize(device)
+        elif "xpu" in device.lower():
+            torch.xpu.synchronize(device)
+        else:
+            raise RuntimeError("The device {} is not available".format(device))
+
+torch_device_mapping = {"cuda": "cuda:0", "xpu": "xpu:0"}
+
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
-        load_cur_state_dict(module, gguf_loader, prefix)
+        load_cur_state_dict(module, gguf_loader, prefix, device=device)
         for name, child in module._modules.items():
-            load_weights(child, gguf_loader, prefix+name+".")
+            load_weights(child, gguf_loader, prefix+name+".", device=device)
     else:
         module.load()
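sync_all_device() replaces the per-device torch.cuda.synchronize() loop used further down in prefill_and_generate, and torch_device_mapping resolves bare backend names to a default device index. A short usage sketch, assuming a build that includes this commit:

    from ktransformers.util.utils import sync_all_device

    # Block until all queued kernels on the listed devices have finished;
    # devices from an unknown backend raise RuntimeError.
    sync_all_device(["cuda:0", "cuda:1"])   # or ["xpu:0"] on an Intel GPU build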
@@ -194,8 +212,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
     device_map = model.gguf_loader.tensor_device_map
-    torch_device = get_device('blk.0.self_attn', device_map)
-    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    torch_device = get_device('model.layers.0.self_attn', device_map)
+    torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
     inputs = inputs.to(torch_device)
     all_cuda_device = get_all_used_cuda_device(device_map)
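The device for the first attention block is now looked up under the model.layers.* naming and normalized through torch_device_mapping instead of a CUDA-only special case. A small stand-alone restatement of the lookup (not code from the commit):

    torch_device_mapping = {"cuda": "cuda:0", "xpu": "xpu:0"}

    def resolve_device(name: str) -> str:
        # Bare backend names map to device 0; already-indexed names pass through.
        return torch_device_mapping.get(name, name)

    assert resolve_device("cuda") == "cuda:0"
    assert resolve_device("xpu:1") == "xpu:1"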
@@ -208,7 +226,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
             # custom_stream = torch.cuda.Stream()
-            torch.cuda.set_device(torch_device)
+            if torch.cuda.is_available():
+                torch.cuda.set_device(torch_device)
+            elif torch.xpu.is_available():
+                torch.xpu.set_device(torch_device)
+            else:
+                RuntimeError("The device: {torch_device} is not available")
             inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
             # with torch.cuda.stream(custom_stream):
             logits = model(inputs_embeds=inputs_embeds,
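Note that in the fallback branch the RuntimeError is constructed but never raised, and the message is a plain string, so {torch_device} is not interpolated. A sketch of an equivalent helper that fails loudly (an assumption about the intent, not the committed behavior):

    import torch

    def set_active_device(torch_device: str) -> None:
        # Select the current accelerator, or raise if neither backend is present.
        if torch.cuda.is_available():
            torch.cuda.set_device(torch_device)
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.set_device(torch_device)
        else:
            raise RuntimeError(f"The device {torch_device} is not available")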
@@ -216,10 +239,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                           cache_position=cache_position,
                           past_key_values=past_key_values,
                           return_dict=False, use_cache=True)[0]
-            if past_key_values != None:
+            if past_key_values != None and isinstance(past_key_values, StaticCache):
                 past_key_values.change_seq_length(1)
-        for device in all_cuda_device:
-            torch.cuda.synchronize(device)
+        sync_all_device(all_cuda_device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -245,11 +267,19 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             return logits

-    torch.cuda.set_device(torch_device)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(torch_device)
+    elif torch.xpu.is_available():
+        torch.xpu.set_device(torch_device)
+    else:
+        RuntimeError("The device: {torch_device} is not available")
     with torch.no_grad():

         stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
+        if torch.xpu.is_available():
+            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
+            past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+        elif mode != 'long_context':
             past_key_values = StaticCache(
                 config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
             )
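On XPU the decode-time KV cache comes from ipex_llm (DynamicUnbalancedFp8Cache.from_legacy_cache(None)) instead of the StaticCache used elsewhere, which makes ipex-llm a runtime dependency of the XPU path. A hedged guard for that optional import, not part of this diff:

    try:
        # Only needed on the XPU code path; requires the ipex-llm package.
        from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
    except ImportError:
        DynamicUnbalancedFp8Cache = None  # CUDA-only installs can run without it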
setup.py  (view file @ 142fb7ce)
@@ -39,7 +39,8 @@ try:
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
 except ImportError:
     MUSA_HOME=None

+KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()
 with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"

 class CpuInstructInfo:
@@ -225,6 +226,8 @@ class VersionInfo:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
         elif ROCM_HOME is not None:
             backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
+        elif torch.xpu.is_available():
+            backend_version = f"xpu"
         else:
             raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
         package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
@@ -495,6 +498,8 @@ class CMakeBuild(BuildExtension):
             cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
         elif ROCM_HOME is not None:
             cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
+        elif KTRANSFORMERS_BUILD_XPU:
+            cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
         else:
             raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")
@@ -620,29 +625,36 @@ elif MUSA_HOME is not None:
             ]
         }
     )
+elif torch.xpu.is_available():
+    #XPUExtension is not available now.
+    ops_module = None
 else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

-ext_modules = [
-    CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
-    ops_module,
-    CUDAExtension('vLLMMarlin', [
-        'csrc/custom_marlin/binding.cpp',
-        'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
-        'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
-        ],
-        extra_compile_args={
-            'cxx': ['-O3'],
-            'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
-        },
-    )
-]
-if with_balance:
-    print("using balance_serve")
-    ext_modules.append(
-        CMakeExtension("balance_serve", os.fspath(Path("").resolve() / "csrc" / "balance_serve"))
-    )
+if not torch.xpu.is_available():
+    ext_modules = [
+        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+        ops_module,
+        CUDAExtension('vLLMMarlin', [
+            'csrc/custom_marlin/binding.cpp',
+            'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
+            'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
+            ],
+            extra_compile_args={
+                'cxx': ['-O3'],
+                'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
+            },
+        )
+    ]
+    if with_balance:
+        print("using balance_serve")
+        ext_modules.append(
+            CMakeExtension("balance_serve", os.fspath(Path("").resolve() / "csrc" / "balance_serve"))
+        )
+else:
+    ext_modules = [
+        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+    ]

 setup(
     name=VersionInfo.PACKAGE_NAME,
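The net effect on what gets compiled: CUDA/ROCm/MUSA builds keep cpuinfer_ext, ops_module, and the vLLMMarlin CUDA extension (plus balance_serve when USE_BALANCE_SERVE=1), while XPU builds compile only cpuinfer_ext, since no XPUExtension exists yet. A compact restatement, not code from setup.py:

    import torch

    # Which native extensions are requested per backend after this change.
    if not torch.xpu.is_available():
        requested = ["cpuinfer_ext", "ops_module", "vLLMMarlin"]  # + "balance_serve" if USE_BALANCE_SERVE=1
    else:
        requested = ["cpuinfer_ext"]  # XPUExtension is not available yet
    print("extensions:", requested)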