ox696c / ktransformers

Unverified commit 25c5bddd, authored Feb 20, 2025 by Azure, committed by GitHub on Feb 20, 2025

Merge pull request #506 from makllama/musa

feat: Support Moore Threads GPU
Parents: 1dd84b4a, 2207f6cd

Showing 8 changed files with 145 additions and 34 deletions (+145 −34)
.gitignore (+1 −0)
ktransformers/ktransformers_ext/CMakeLists.txt (+35 −5)
ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h (+5 −1)
ktransformers/ktransformers_ext/cpu_backend/vendors/README.md (+3 −0)
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h (+3 −0)
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h (+7 −0)
ktransformers/ktransformers_ext/cuda/binding.cpp (+6 −2)
setup.py (+85 −26)
.gitignore

@@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt
mmlu_result_q4km.json
mmlu_result_q4km.log
ktransformers/tests/mmlu_result_silicon.log
+ktransformers/ktransformers_ext/cuda_musa/
ktransformers/ktransformers_ext/CMakeLists.txt

@@ -30,6 +30,8 @@ if (NOT MSVC)
    option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
+option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
+option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)

# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
if (WIN32)
    include_directories("$ENV{CUDA_PATH}/include")
elseif (UNIX)
-    find_package(CUDA REQUIRED)
-    include_directories("${CUDA_INCLUDE_DIRS}")
+    if (KTRANSFORMERS_USE_CUDA)
+        find_package(CUDA REQUIRED)
+        include_directories("${CUDA_INCLUDE_DIRS}")
+        add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
+    endif()
+    if (KTRANSFORMERS_USE_MUSA)
+        if (NOT EXISTS $ENV{MUSA_PATH})
+            if (NOT EXISTS /opt/musa)
+                set(MUSA_PATH /usr/local/musa)
+            else()
+                set(MUSA_PATH /opt/musa)
+            endif()
+        else()
+            set(MUSA_PATH $ENV{MUSA_PATH})
+        endif()
+        list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
+        find_package(MUSAToolkit)
+        if (MUSAToolkit_FOUND)
+            message(STATUS "MUSA Toolkit found")
+            add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
+        endif()
+    endif()
endif()

aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
@@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if (WIN32)
    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib") #CUDA::cudart
elseif (UNIX)
-    if (NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
-        set(ENV{CUDA_HOME} "/usr/local/cuda")
-    endif()
-    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
+    if (KTRANSFORMERS_USE_CUDA)
+        if (NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
+            set(ENV{CUDA_HOME} "/usr/local/cuda")
+        endif()
+        target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
+    endif()
+    if (KTRANSFORMERS_USE_MUSA)
+        target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
+    endif()
endif()

# Define the USE_NUMA option
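The new CMake branch resolves the MUSA toolkit location in a fixed order: an existing MUSA_PATH environment variable wins, then /opt/musa, then /usr/local/musa as the final fallback. The following Python sketch (illustrative only, not part of the commit) mirrors that resolution order:

import os

def resolve_musa_path() -> str:
    # Mirrors the CMake logic above: honour $MUSA_PATH when it points at an
    # existing directory, otherwise prefer /opt/musa, otherwise /usr/local/musa.
    env_path = os.environ.get("MUSA_PATH", "")
    if env_path and os.path.exists(env_path):
        return env_path
    if os.path.exists("/opt/musa"):
        return "/opt/musa"
    return "/usr/local/musa"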
ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h

@@ -17,7 +17,11 @@
#include <queue>
#include <thread>
#include <vector>
-#include "cuda_runtime.h"
+#ifdef KTRANSFORMERS_USE_CUDA
+#include "vendors/cuda.h"
+#elif KTRANSFORMERS_USE_MUSA
+#include "vendors/musa.h"
+#endif

#include "backend.h"
#include "task_queue.h"
ktransformers/ktransformers_ext/cpu_backend/vendors/README.md (new file, mode 100644)

## TODO

This directory can be removed after updating the version of `llama.cpp`.
\ No newline at end of file
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h (new file, mode 100644)

#pragma once

#include <cuda_runtime.h>
\ No newline at end of file
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h (new file, mode 100644)

#pragma once

#include <musa_runtime.h>

#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaStream_t musaStream_t
#define cudaHostFn_t musaHostFn_t
\ No newline at end of file
ktransformers/ktransformers_ext/cuda/binding.cpp

/**
 * @Description  :
 * @Author       : Azure-Tang
 * @Date         : 2024-07-25 13:38:30
 * @Version      : 1.0.0
 * @LastEditors  : kkk1nak0
 * @LastEditTime : 2024-08-12 03:05:04
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "custom_gguf/ops.h"
+#ifdef KTRANSFORMERS_USE_CUDA
#include "gptq_marlin/ops.h"
+#endif
// Python bindings
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

@@ -33,8 +35,10 @@ PYBIND11_MODULE(KTransformersOps, m) {
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
         py::arg("data"), py::arg("blk_size"), py::arg("device"));
+#ifdef KTRANSFORMERS_USE_CUDA
    m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
         py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
         py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
         py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
+#endif
}
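Because gptq_marlin_gemm is now registered only when the extension is compiled with KTRANSFORMERS_USE_CUDA, Python code running against a MUSA build will not find that attribute on the module. A minimal feature-detection sketch (an illustration, not part of this commit):

import KTransformersOps

# The Marlin GEMM binding exists only in CUDA builds (see the #ifdef guard above),
# so probe for it instead of assuming it is present.
if hasattr(KTransformersOps, "gptq_marlin_gemm"):
    print("Marlin GEMM available (CUDA build)")
else:
    print("Marlin GEMM not bound in this build (e.g. MUSA); use a fallback path")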
setup.py

#!/usr/bin/env python
# coding=utf-8
'''
Description  :
Author       : chenxl
Date         : 2024-07-27 16:15:27
Version      : 1.0.0
LastEditors  : chenxl
LastEditTime : 2024-08-14 16:36:19
Adapted from:
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
Copyright (c) 2023, Tri Dao.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os
@@ -30,6 +30,11 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension
from cpufeature.extension import CPUFeature
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+try:
+    from torch_musa.utils.simple_porting import SimplePorting
+    from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
+except ImportError:
+    MUSA_HOME = None

class CpuInstructInfo:
    CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -40,7 +45,7 @@ class CpuInstructInfo:
    CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
    CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
    CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"

class VersionInfo:
    THIS_DIR = os.path.dirname(os.path.abspath(__file__))
    PACKAGE_NAME = "ktransformers"
@@ -49,6 +54,16 @@ class VersionInfo:
    )
    FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"

+    def get_musa_bare_metal_version(self, musa_dir):
+        raw_output = subprocess.run(
+            [musa_dir + "/bin/mcc", "-v"], check=True,
+            stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8")
+        output = raw_output.split()
+        release_idx = output.index("version") + 1
+        bare_metal_version = parse(output[release_idx].split(",")[0])
+        musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
+        return musa_version
+
    def get_cuda_bare_metal_version(self, cuda_dir):
        raw_output = subprocess.check_output(
            [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -58,7 +73,7 @@ class VersionInfo:
        cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
        return cuda_version

-    def get_cuda_version_of_torch(self,):
+    def get_cuda_version_of_torch(self):
        torch_cuda_version = parse(torch.version.cuda)
        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
        return cuda_version
@@ -117,7 +132,7 @@ class VersionInfo:
        torch_version_raw = parse(torch.__version__)
        torch_version = f"{torch_version_raw.major}{torch_version_raw.minor}"
        return torch_version

    def get_flash_version(self,):
        version_file = os.path.join(
            Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py")
@@ -128,12 +143,21 @@ class VersionInfo:
        return flash_version

    def get_package_version(self, full_version=False):
-        flash_version = self.get_flash_version()
-        package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}"
+        flash_version = str(self.get_flash_version())
+        torch_version = self.get_torch_version()
+        cpu_instruct = self.get_cpu_instruct()
+        backend_version = ""
+        if CUDA_HOME is not None:
+            backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
+        elif MUSA_HOME is not None:
+            backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
+        else:
+            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+        package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
        if full_version:
            return package_version
        if not VersionInfo.FORCE_BUILD:
-            return str(flash_version)
+            return flash_version
        return package_version
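For illustration only (the values below are hypothetical, not taken from the commit): on a machine where torch_musa is installed, mcc reports version 3.1, torch is 2.5, and CPU_INSTRUCT=FANCY, the new logic composes a wheel tag along these lines:

# Hypothetical inputs, shown only to illustrate the new "+mu..." tag format.
flash_version = "0.2.1"      # whatever ktransformers/__init__.py declares
backend_version = "mu31"     # from get_musa_bare_metal_version(MUSA_HOME)
package_version = f"{flash_version}+{backend_version}torch25fancy"
# -> "0.2.1+mu31torch25fancy"; a CUDA build would carry "+cu..." instead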
@@ -218,11 +242,19 @@ class CMakeBuild(BuildExtension):
            f"-DPYTHON_EXECUTABLE={sys.executable}",
            f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
        ]
+        if CUDA_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
+        elif MUSA_HOME is not None:
+            cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
+        else:
+            raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+
        build_args = []
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [
                item for item in os.environ["CMAKE_ARGS"].split(" ") if item
            ]

        if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
            cpu_args = CpuInstructInfo.CMAKE_FANCY
        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:

@@ -231,7 +263,7 @@ class CMakeBuild(BuildExtension):
            cpu_args = CpuInstructInfo.CMAKE_AVX2
        else:
            cpu_args = CpuInstructInfo.CMAKE_NATIVE

        cmake_args += [
            item for item in cpu_args.split(" ") if item
        ]
@@ -288,28 +320,55 @@ class CMakeBuild(BuildExtension):
            print("Standard output:", result.stdout)
            print("Standard error:", result.stderr)
        subprocess.run(
-            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
+            ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
        )

+if CUDA_HOME is not None:
+    ops_module = CUDAExtension('KTransformersOps', [
+        'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
+        'ktransformers/ktransformers_ext/cuda/binding.cpp',
+        'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
+    ],
+        extra_compile_args={
+            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
+            'nvcc': [
+                '-O3',
+                '--use_fast_math',
+                '-Xcompiler', '-fPIC',
+                '-DKTRANSFORMERS_USE_CUDA',
+            ]
+        }
+    )
+elif MUSA_HOME is not None:
+    SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
+        # Common rules
+        "at::cuda": "at::musa",
+        "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
+        "#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
+    }).run()
+    ops_module = MUSAExtension('KTransformersOps', [
+        'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
+        'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
+        # TODO: Add Marlin support for MUSA.
+        # 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
+    ],
+        extra_compile_args={
+            'cxx': ['force_mcc'],
+            'mcc': [
+                '-O3',
+                '-DKTRANSFORMERS_USE_MUSA',
+                '-DTHRUST_IGNORE_CUB_VERSION_CHECK',
+            ]
+        }
+    )
+else:
+    raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
+
setup(
    version=VersionInfo().get_package_version(),
    cmdclass={"bdist_wheel": BuildWheelsCommand, "build_ext": CMakeBuild},
    ext_modules=[
        CMakeExtension("cpuinfer_ext"),
-        CUDAExtension('KTransformersOps', [
-            'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
-            'ktransformers/ktransformers_ext/cuda/binding.cpp',
-            'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
-        ],
-            extra_compile_args={
-                'cxx': ['-O3'],
-                'nvcc': [
-                    '-O3',
-                    '--use_fast_math',
-                    '-Xcompiler', '-fPIC',
-                ]
-            }
-        )
+        ops_module,
    ]
)
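Taken together, the setup.py changes make backend selection implicit: CUDA_HOME picks the CUDA path, and MUSA_HOME (set once torch_musa is importable) picks the MUSA path. A build-invocation sketch under those assumptions (the toolkit path and pip flags are illustrative, not prescribed by the commit):

import os
import subprocess

# Point CMake at a MUSA toolkit (adjust to the local install) and force a
# source build so the new MUSA code path is exercised.
env = dict(os.environ, MUSA_PATH="/usr/local/musa", KTRANSFORMERS_FORCE_BUILD="TRUE")
subprocess.run(["pip", "install", ".", "--no-build-isolation"], env=env, check=True)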