ox696c / ktransformers · Commit f7ee993f (unverified)

Authored May 14, 2025 by aubreyli; committed by GitHub on May 14, 2025

Merge pull request #1295 from rnwang04/xpu_support

Enable ktransformers on Intel GPU with local chat backend

Parents: 333351c7, 142fb7ce
Showing 2 changed files with 77 additions and 35 deletions:

    ktransformers/util/utils.py   +44 -14
    setup.py                      +33 -21
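Throughout the diff, CUDA-only code paths are gated behind runtime backend checks so the local chat backend can also run on Intel GPUs. A minimal sketch of the probe pattern the patch relies on, using only the stock torch.cuda / torch.xpu APIs (standalone illustration, not code from this commit):

    import torch

    # Pick the accelerator the same way the patched code does: prefer CUDA,
    # fall back to Intel XPU, otherwise stay on CPU.
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.xpu.is_available():
        backend = "xpu"
    else:
        backend = "cpu"
    print(f"selected backend: {backend}")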
ktransformers/util/utils.py
@@ -27,7 +27,8 @@ from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
-from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+if not torch.xpu.is_available():
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 import socket
 
 warm_uped = False
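The hunk above makes the flashinfer-based MLA wrapper a CUDA-only import. A hedged sketch of the same guarded-import pattern with an explicit fallback, so callers can test for availability instead of re-probing the backend (the None fallback and the helper are illustrative, not part of this commit):

    import torch

    # flashinfer ships CUDA kernels only, so skip the import on Intel XPU hosts.
    if not torch.xpu.is_available():
        from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
    else:
        MLAWrapperSingleton = None  # illustrative fallback

    def mla_wrapper_available() -> bool:
        # Callers can branch on the fallback value rather than on torch.xpu.
        return MLAWrapperSingleton is not None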
@@ -59,6 +60,8 @@ def get_compute_capability(device:torch.device = None):
             return min_compute_capability_major
         else:
             return torch.cuda.get_device_properties(device)
+    else:
+        return 0
 
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
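get_compute_capability() now returns 0 when CUDA is absent, which gives XPU (and CPU) hosts a safe value for callers that gate CUDA-specific fast paths. A hedged sketch of such a caller, assuming the common SM80+ requirement for flashinfer-style kernels; the helper name and threshold are illustrative, not from this commit:

    import torch
    from ktransformers.util.utils import get_compute_capability

    def can_use_flashinfer_mla() -> bool:
        # Illustrative gate: evaluates to False on XPU/CPU because
        # get_compute_capability() falls back to 0 there.
        return torch.cuda.is_available() and get_compute_capability() >= 8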
@@ -97,7 +100,7 @@ def get_all_used_cuda_device(device_map:dict):
     all_device_list = list(all_device_list)
     return all_device_list
 
-def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
+def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
     local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
@@ -118,7 +121,10 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.xpu.is_available():
+                torch.xpu.empty_cache()
             weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
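The CUDA-or-XPU cache flush above recurs in later hunks. A minimal sketch of factoring it into one helper, assuming only the stock torch.cuda / torch.xpu APIs (the empty_device_cache name is hypothetical, not part of this commit):

    import torch

    def empty_device_cache() -> None:
        # Release cached allocator blocks on whichever accelerator is present;
        # harmless no-op on CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif torch.xpu.is_available():
            torch.xpu.empty_cache()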
@@ -126,12 +132,24 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
             #print(load_config.tensor_file_map.keys())
             raise Exception(f"can't find {translated_key} in GGUF file!")
 
-def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
+def sync_all_device(all_device_list):
+    for device in all_device_list:
+        if "cuda" in device.lower():
+            torch.cuda.synchronize(device)
+        elif "xpu" in device.lower():
+            torch.xpu.synchronize(device)
+        else:
+            raise RuntimeError("The device {} is not available".format(device))
+
+torch_device_mapping = {"cuda": "cuda:0", "xpu": "xpu:0"}
+
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
-        load_cur_state_dict(module, gguf_loader, prefix)
+        load_cur_state_dict(module, gguf_loader, prefix, device=device)
         for name, child in module._modules.items():
-            load_weights(child, gguf_loader, prefix+name+".")
+            load_weights(child, gguf_loader, prefix+name+".", device=device)
     else:
         module.load()
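With the new device parameter on load_weights/load_cur_state_dict and the sync_all_device helper, weight loading and cross-device synchronization become backend-agnostic. A hedged usage sketch, assuming a model and gguf_loader built by the usual ktransformers path (those two variables are placeholders, not defined here):

    import torch
    from ktransformers.util.utils import load_weights, sync_all_device, torch_device_mapping

    target = "xpu" if torch.xpu.is_available() else "cuda"
    load_weights(model, gguf_loader, device=target)       # model/gguf_loader built elsewhere

    # Bare backend names are canonicalized to a concrete device index.
    device = torch_device_mapping.get(target, target)     # e.g. "xpu" -> "xpu:0"
    sync_all_device([device])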
@@ -194,8 +212,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
     device_map = model.gguf_loader.tensor_device_map
-    torch_device = get_device('blk.0.self_attn', device_map)
-    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    torch_device = get_device('model.layers.0.self_attn', device_map)
+    torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
     inputs = inputs.to(torch_device)
     all_cuda_device = get_all_used_cuda_device(device_map)
@@ -208,7 +226,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
             # custom_stream = torch.cuda.Stream()
-            torch.cuda.set_device(torch_device)
+            if torch.cuda.is_available():
+                torch.cuda.set_device(torch_device)
+            elif torch.xpu.is_available():
+                torch.xpu.set_device(torch_device)
+            else:
+                RuntimeError("The device: {torch_device} is not available")
             inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
             # with torch.cuda.stream(custom_stream):
             logits = model(inputs_embeds=inputs_embeds,
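Note that the fallback branch above constructs RuntimeError without raising it, and the message is a plain string rather than an f-string. A hedged sketch of a stricter dispatch helper (the set_active_device name is illustrative, not part of this commit):

    import torch

    def set_active_device(torch_device: str) -> None:
        # Select the accelerator that runs the decode step; fail loudly
        # instead of silently falling through on unsupported hosts.
        if torch.cuda.is_available():
            torch.cuda.set_device(torch_device)
        elif torch.xpu.is_available():
            torch.xpu.set_device(torch_device)
        else:
            raise RuntimeError(f"The device: {torch_device} is not available")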
@@ -216,10 +239,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                            cache_position=cache_position,
                            past_key_values=past_key_values,
                            return_dict=False, use_cache=True)[0]
-        if past_key_values != None:
+        if past_key_values != None and isinstance(past_key_values, StaticCache):
             past_key_values.change_seq_length(1)
-        for device in all_cuda_device:
-            torch.cuda.synchronize(device)
+        sync_all_device(all_cuda_device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -245,11 +267,19 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             return logits
 
-    torch.cuda.set_device(torch_device)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(torch_device)
+    elif torch.xpu.is_available():
+        torch.xpu.set_device(torch_device)
+    else:
+        RuntimeError("The device: {torch_device} is not available")
     with torch.no_grad():
         stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
+        if torch.xpu.is_available():
+            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
+            past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+        elif mode != 'long_context':
             past_key_values = StaticCache(
                 config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
             )
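On XPU the commit replaces ktransformers' StaticCache with ipex-llm's DynamicUnbalancedFp8Cache. A hedged sketch of the same backend-dependent cache selection in isolation, assuming ipex-llm is installed on the XPU host (the build_kv_cache wrapper is illustrative, not part of this commit):

    import torch
    from ktransformers.models.custom_cache import StaticCache

    def build_kv_cache(model, seq_length: int, max_new_tokens: int, device_map):
        if torch.xpu.is_available():
            # ipex-llm supplies an FP8 KV cache implementation for Intel GPUs.
            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
            return DynamicUnbalancedFp8Cache.from_legacy_cache(None)
        return StaticCache(
            config=model.config,
            max_batch_size=1,
            max_cache_len=seq_length + max_new_tokens,
            device=device_map,
            dtype=model.dtype,
        )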
setup.py
@@ -39,7 +39,8 @@ try:
     from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
 except ImportError:
     MUSA_HOME = None
 
+KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()
 with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"
 
 class CpuInstructInfo:
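setup.py now derives the XPU build switch from the installed torch at build time, while balance_serve remains an explicit opt-in through the USE_BALANCE_SERVE environment variable. A minimal sketch of probing both switches the same way (standalone illustration, not code from setup.py):

    import os
    import torch

    build_xpu = torch.xpu.is_available()                              # follows the installed torch
    with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1"    # explicit opt-in
    print(f"XPU build: {build_xpu}, balance_serve: {with_balance}")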
@@ -225,6 +226,8 @@ class VersionInfo:
             backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
         elif ROCM_HOME is not None:
             backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
+        elif torch.xpu.is_available():
+            backend_version = f"xpu"
         else:
             raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set.")
         package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
@@ -495,6 +498,8 @@ class CMakeBuild(BuildExtension):
             cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
         elif ROCM_HOME is not None:
             cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
+        elif KTRANSFORMERS_BUILD_XPU:
+            cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
         else:
             raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set.")
@@ -620,29 +625,36 @@ elif MUSA_HOME is not None:
         ]
     }
 )
+elif torch.xpu.is_available():
+    #XPUExtension is not available now.
+    ops_module = None
 else:
     raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
 
-ext_modules = [
-    CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
-    ops_module,
-    CUDAExtension('vLLMMarlin', [
-        'csrc/custom_marlin/binding.cpp',
-        'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
-        'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
-    ],
-    extra_compile_args={
-        'cxx': ['-O3'],
-        'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
-    },
-    )
-]
-if with_balance:
-    print("using balance_serve")
-    ext_modules.append(
-        CMakeExtension("balance_serve", os.fspath(Path("").resolve() / "csrc" / "balance_serve"))
-    )
+if not torch.xpu.is_available():
+    ext_modules = [
+        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+        ops_module,
+        CUDAExtension('vLLMMarlin', [
+            'csrc/custom_marlin/binding.cpp',
+            'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',
+            'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'],
+            'nvcc': ['-O3', '-Xcompiler', '-fPIC'],
+        },
+        )
+    ]
+    if with_balance:
+        print("using balance_serve")
+        ext_modules.append(
+            CMakeExtension("balance_serve", os.fspath(Path("").resolve() / "csrc" / "balance_serve"))
+        )
+else:
+    ext_modules = [
+        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+    ]
 
 setup(
     name=VersionInfo.PACKAGE_NAME,