ox696c / ktransformers / Commits

Commit ed843741, authored Mar 14, 2025 by Azure-Tang
merge main; Add torch q8 linear
Parent: 6c4ed591

Changes: 27. Showing 20 changed files on this page, with 1021 additions and 99 deletions.
ktransformers/ktransformers_ext/CMakeLists.txt                            +39   -0
ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h                    +80   -76
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h                +13   -1
ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h                 +172  -0
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h                +131  -3
ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h              +13   -0
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu               +1    -0
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu           +1    -1
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh          +1    -1
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh   +5    -0
ktransformers/ktransformers_ext/ext_bindings.cpp                          +0    -1
ktransformers/ktransformers_ext/vendors/cuda.h                            +15   -0
ktransformers/ktransformers_ext/vendors/hip.h                             +172  -0
ktransformers/ktransformers_ext/vendors/musa.h                            +137  -0
ktransformers/ktransformers_ext/vendors/vendor.h                          +13   -0
ktransformers/local_chat.py                                               +2    -1
ktransformers/operators/attention.py                                      +26   -12
ktransformers/operators/dynamic_attention.py                              +4    -1
ktransformers/operators/linear.py                                         +192  -0
ktransformers/operators/models.py                                         +4    -2
ktransformers/ktransformers_ext/CMakeLists.txt

@@ -32,6 +32,7 @@ endif()
 option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
 option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
 option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
+option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)

 # Architecture specific
 # TODO: probably these flags need to be tweaked on some architectures
...
@@ -201,6 +202,31 @@ endif()
 #     message(STATUS "Can't found CUDA lib")
 # endif()
+if (NOT EXISTS $ENV{ROCM_PATH})
+    if (NOT EXISTS /opt/rocm)
+        set(ROCM_PATH /usr)
+    else()
+        set(ROCM_PATH /opt/rocm)
+    endif()
+else()
+    set(ROCM_PATH $ENV{ROCM_PATH})
+endif()
+list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
+
+if (NOT EXISTS $ENV{MUSA_PATH})
+    if (NOT EXISTS /opt/musa)
+        set(MUSA_PATH /usr/local/musa)
+    else()
+        set(MUSA_PATH /opt/musa)
+    endif()
+else()
+    set(MUSA_PATH $ENV{MUSA_PATH})
+endif()
+list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
+
 add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
 add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
...
@@ -218,6 +244,14 @@ elseif (UNIX)
         add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
     endif()
+    if (KTRANSFORMERS_USE_ROCM)
+        find_package(HIP REQUIRED)
+        if (HIP_FOUND)
+            include_directories("${HIP_INCLUDE_DIRS}")
+            add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
+        endif()
+    endif()
     if (KTRANSFORMERS_USE_MUSA)
         if (NOT EXISTS $ENV{MUSA_PATH})
             if (NOT EXISTS /opt/musa)
...
@@ -258,6 +292,11 @@ elseif(UNIX)
     endif()
     target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
 endif()
+if (KTRANSFORMERS_USE_ROCM)
+    add_compile_definitions(USE_HIP=1)
+    target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
+    message(STATUS "Building for HIP")
+endif()
 if (KTRANSFORMERS_USE_MUSA)
     target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
 endif()
...
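The ROCM_PATH and MUSA_PATH fallback chains above are easy to misread. Here is the same resolution order restated as a small Python sketch (illustration only, not part of the commit):

    import os

    def resolve_rocm_path() -> str:
        # Mirrors the CMake logic: $ENV{ROCM_PATH} wins, then /opt/rocm, then /usr.
        env = os.environ.get("ROCM_PATH")
        if env and os.path.exists(env):
            return env
        return "/opt/rocm" if os.path.exists("/opt/rocm") else "/usr"

    def resolve_musa_path() -> str:
        # Same shape for MUSA: $ENV{MUSA_PATH}, then /opt/musa, then /usr/local/musa.
        env = os.environ.get("MUSA_PATH")
        if env and os.path.exists(env):
            return env
        return "/opt/musa" if os.path.exists("/opt/musa") else "/usr/local/musa"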
ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h

@@ -7,79 +7,83 @@
 * @LastEditTime : 2024-08-07 09:47:43
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/

The hunk re-emits the whole header, but the class body is identical on both sides; the functional delta is the vendor-include block, which gains a ROCm branch (the three lines after the MUSA case) and the new shared include "../vendors/vendor.h". The header after the change:

#ifndef CPUINFER_CPUINFER_H
#define CPUINFER_CPUINFER_H

#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
#include "vendors/musa.h"
#elif KTRANSFORMERS_USE_ROCM
#define __HIP_PLATFORM_AMD__
#include "vendors/hip.h"
#endif

#include "backend.h"
#include "task_queue.h"
#include "../vendors/vendor.h"

#include "llama.cpp/ggml-impl.h"

class CPUInfer {
  public:
    CPUInfer(int thread_num) {
        backend_ = new Backend(thread_num - 1);
        task_queue_ = new TaskQueue();
        for (int i = 0; i < (1 << 16); ++i) {
            ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
        }
    }

    ~CPUInfer() {
        delete backend_;
        delete task_queue_;
    }

    template <typename Func, typename Obj, typename... Args>
    void enqueue(Func f, Obj* obj, Args... args) {
        task_queue_->enqueue([=]() {
            std::invoke(f, *obj, args..., backend_);
        });
    }

    void submit(std::pair<intptr_t, intptr_t> params) {
        void (*func)(void*) = (void (*)(void*))params.first;
        void* args = (void*)params.second;
        *((CPUInfer**)args) = this;
        func(args);
    }

    void sync() {
        task_queue_->sync();
    }

    void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
        void (*func)(void*) = (void (*)(void*))params.first;
        void* args = (void*)params.second;
        *((CPUInfer**)args) = this;
        cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
    }

    static void sync_(void* cpu_infer_ptr) {
        CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
        cpuinfer->sync();
    }

    void sync_with_cuda_stream(intptr_t user_cuda_stream) {
        cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
    }

  public:
    Backend* backend_;
    TaskQueue* task_queue_;
};

#endif
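For orientation, this is roughly how the pair-of-intptr_t protocol above is driven from Python. The wrapper lives in ktransformers.operators.cpuinfer; `make_task` stands in for any cpuinfer_ext operator method that returns the (func_ptr, args_ptr) pair described above (a hypothetical name, for illustration only):

    from ktransformers.operators.cpuinfer import CPUInfer

    def run_on_cpu_backend(make_task, thread_num: int = 48) -> None:
        cpu_infer = CPUInfer(thread_num)   # Backend gets thread_num - 1 worker threads
        task = make_task()                 # -> (func_ptr, args_ptr); args begins with a CPUInfer* slot
        cpu_infer.submit(task)             # C++ side writes `this` into args, then calls func(args)
        cpu_infer.sync()                   # blocks until the task queue drains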
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h

Previously this header held only the runtime include; the commit expands it:

#pragma once

#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h (new file, mode 100644)
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaHostFn_t hipHostFn_t
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef hip_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
#pragma once
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaHostFn_t musaHostFn_t
(removed: #define nv_bfloat16 mt_bfloat16, with no newline at end of file; superseded by the typedef at the end of the file)
#define cudaSuccess musaSuccess
// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h (new file, mode 100644)
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H
#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif
#endif // CPUINFER_VENDOR_VENDOR_H
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu

@@ -15,6 +15,7 @@
 #include <torch/torch.h>
 #include <cstdint>
 #include <c10/cuda/CUDAGuard.h>
+typedef hip_bfloat16 nv_bfloat16;

 __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
...
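For reference, dequantize_q8_0_fp32_kernel walks GGML Q8_0 blocks: each block is a 2-byte fp16 scale followed by 32 int8 quants (so, assuming blk_size counts bytes per block, blk_size = 34 and ele_per_blk = 32). A NumPy sketch of the same transform, for illustration only:

    import numpy as np

    def dequantize_q8_0(raw: bytes, n_blocks: int) -> np.ndarray:
        # Each Q8_0 block is 34 bytes: one fp16 scale + 32 int8 quants.
        blk = np.frombuffer(raw, dtype=np.uint8).reshape(n_blocks, 34)
        scales = blk[:, :2].copy().view(np.float16).astype(np.float32)  # (n_blocks, 1)
        qs = blk[:, 2:].copy().view(np.int8).astype(np.float32)         # (n_blocks, 32)
        return (qs * scales).reshape(-1)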
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu

@@ -36,7 +36,7 @@ inline std::string str(T x) {
 namespace gptq_marlin {

-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)

 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
...
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh

@@ -39,7 +39,7 @@ using I4 = Vec<int, 4>;
 constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)
 // No support for async
 #else
...
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh

@@ -8,6 +8,11 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>

+#ifdef __HIP_PLATFORM_AMD__
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
+#endif

 namespace gptq_marlin {

 template <typename scalar_t>
...
ktransformers/ktransformers_ext/ext_bindings.cpp

@@ -9,7 +9,6 @@
 **/
 // Python bindings
 #include "cpu_backend/cpuinfer.h"
-#include "device_launch_parameters.h"
 #include "llamafile/flags.h"
 #include "operators/kvcache/kvcache.h"
 #include "operators/llamafile/linear.h"
...
ktransformers/ktransformers_ext/vendors/cuda.h (new file, mode 100644)
#pragma once
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
ktransformers/ktransformers_ext/vendors/hip.h (new file, mode 100644)
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaHostFn_t hipHostFn_t
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef hip_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/vendors/musa.h (new file, mode 100644)
#pragma once
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaSuccess musaSuccess
// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/vendors/vendor.h (new file, mode 100644)
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H
#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif
#endif // CPUINFER_VENDOR_VENDOR_H
ktransformers/local_chat.py

@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
...
@@ -169,7 +170,7 @@ def local_chat(
     assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
     "please change max_seq_len in ~/.ktransformers/config.yaml"

-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think, chunk_prefill_size=chunk_prefill_size,
             use_flashinfer_mla=True, num_heads=config.num_attention_heads, head_dim_ckv=config.kv_lora_rank, head_dim_kpe=config.qk_rope_head_dim, q_head_dim=config.qk_rope_head_dim + config.qk_nope_head_dim
...
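Condensed, the gate this hunk tightens reads as follows (a restatement of the condition above with the inputs made explicit, for illustration only):

    from ktransformers.util.vendors import GPUVendor

    def use_flashinfer_mla_path(system: str, arch: str, flashinfer_enabled: bool,
                                compute_cap: int, gpu_vendor: GPUVendor) -> bool:
        # Mirrors the guard in local_chat; the commit adds the final vendor check.
        return (
            system != "Windows"
            and arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM")
            and flashinfer_enabled
            and compute_cap >= 8
            and gpu_vendor == GPUVendor.NVIDIA   # new in this commit: NVIDIA-only
        )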
ktransformers/operators/attention.py

@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+
+try:
+    from flash_attn import flash_attn_func
+except:
+    pass
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
...
@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

-        attn_output = flash_attn_func(
-            query_states, key_states, value_states_padded,
-            softmax_scale=self.softmax_scale, causal=True,
-        )
+        # for bsz = 1
+        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
+        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
+        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
+        max_input_len = q_len
+        context_attention_fwd(
+            q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+            k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+            v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
+            o=attn_output,
+            b_start_loc=b_start_loc,
+            b_seq_len=b_seq_len,
+            max_input_len=max_input_len,
+            is_causal=True,
+        )

         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, :, :self.v_head_dim]
+            attn_output = attn_output[:, :, :self.v_head_dim]

         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
...
@@ -589,8 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability() < 8:
-            print("for Windows or GPU before ampere, use forward_windows")
+        if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
...
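The replacement prefill path flattens the batch for the Triton kernel: queries go from (bsz, q_len, heads, dim) to (bsz * q_len, heads, dim), and b_start_loc / b_seq_len record where each sequence starts and how long it is. A standalone sketch of that metadata for the bsz = 1 case the comment calls out (illustrative only):

    import torch

    bsz, q_len = 1, 512
    b_start_loc = torch.zeros(bsz, dtype=torch.int64)         # sequence i starts at row i * q_len
    b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64)  # every sequence spans q_len rows
    # For bsz > 1 with equal lengths, the start offsets would be 0, q_len, 2*q_len, ...
    # b_start_loc = torch.arange(bsz, dtype=torch.int64) * q_len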
ktransformers/operators/dynamic_attention.py

@@ -17,7 +17,10 @@ import logging
 logger = logging.getLogger("dynamic_attention")
 sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
 from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+try:
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
+except:
+    print("falsh attn not found")

 import math
...
ktransformers/operators/linear.py

@@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 import cpuinfer_ext
 from ktransformers.operators.cpuinfer import CPUInfer
 from ktransformers.server.config.config import Config
+from typing import Dict, Tuple, Optional, Union
+import numpy as np

 #class KLinearBase(BaseInjectedModule, ABC):
 class KLinearBase(ABC):
...
@@ -176,6 +178,195 @@ class KLinearTorch(KLinearBase):
         if self.has_bias:
             self.bias = None

The entire KLinearQ8 class below is new; the KLinearTorch.unload tail above and the KLinearFP8 header below are unchanged context.

class KLinearQ8(KLinearBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cuda",
        group_size: int = 128,      # larger group size to reduce quantization noise
        percentile: float = 99.99,  # new: percentile at which outliers are clipped
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.compute_dtype = torch.float32
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None
        self.bias = None
        self.loaded = False
        self.group_size = group_size
        self.percentile = percentile

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        out_device = x.device
        x = x.to(device=self.device, dtype=self.compute_dtype)
        # Mimic the original full-precision behavior:
        # dequantize the weight, then do the matmul
        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
        out = x @ weight_dequant.T
        if self.has_bias:
            out = out + self.bias
        return out.to(dtype=orig_dtype, device=out_device)

    def _dequantize_weight(self, q_matrix, scales, bits=8):
        """
        Dequantize a low-precision matrix back to floating-point

        Args:
            q_matrix (torch.Tensor): Quantized int matrix
            scales (torch.Tensor): Scale factors for each column
            bits (int): Quantization bits used (8 or 4)

        Returns:
            torch.Tensor: Dequantized floating-point matrix
        """
        # Ensure inputs are torch tensors
        if not isinstance(q_matrix, torch.Tensor):
            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
        if not isinstance(scales, torch.Tensor):
            scales = torch.tensor(scales, dtype=torch.float32)

        # Convert to correct dtype if needed
        if q_matrix.dtype != torch.int8:
            q_matrix = q_matrix.to(torch.int8)
        if scales.dtype != torch.float32:
            scales = scales.to(torch.float32)

        # For Q4, ensure the values stay within 4-bit range
        if bits == 4:
            q_matrix = torch.clamp(q_matrix, -7, 7)

        # Get matrix shape
        rows, cols = q_matrix.shape

        # Convert to float32
        dequant_matrix = q_matrix.to(torch.float32)

        # Create broadcasted scales: reshape scales to [1, cols] for broadcasting
        scales_broadcast = scales.view(1, cols)

        # Apply dequantization to all columns at once
        dequant_matrix = dequant_matrix * scales_broadcast

        return dequant_matrix

    def _quantize_weight(self, matrix, bits=8):
        """
        Quantize a floating-point matrix to lower precision (Q8 or Q4)

        Args:
            matrix (torch.Tensor): Input matrix in floating-point format
            bits (int): Quantization bits, either 8 or 4

        Returns:
            tuple: (quantized int matrix, scale factors for each column)
        """
        if not isinstance(matrix, torch.Tensor):
            matrix = torch.tensor(matrix, dtype=torch.float32)

        # Convert to float32 if needed
        if matrix.dtype != torch.float32:
            matrix = matrix.to(torch.float32)

        # Get matrix shape
        rows, cols = matrix.shape

        # Determine quantization parameters based on bits
        if bits == 8:
            # Q8: range is -127 to 127
            max_int = 127
            qtype = torch.int8
        elif bits == 4:
            # Q4: range is -7 to 7 (using 4-bit signed integers)
            max_int = 7
            qtype = torch.int8  # still stored as int8, but limited to the 4-bit range
        else:
            raise ValueError("Quantization bits must be either 8 or 4")

        # Initialize results and scale factors
        q_matrix = torch.zeros_like(matrix, dtype=qtype)
        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)

        # Calculate max absolute value for each column
        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)

        # Handle zero columns (avoid division by zero)
        zero_cols = max_abs_vals == 0
        max_abs_vals[zero_cols] = 1.0

        # Calculate scale factors for all columns at once
        scales = max_abs_vals / max_int

        # Prepare the scales for broadcasting [1, cols]
        scales_broadcast = scales.view(1, cols)

        # Apply quantization to the entire matrix at once
        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)

        # For Q4, clamp values to ensure they stay within 4-bit range
        if bits == 4:
            q_matrix = torch.clamp(q_matrix, -max_int, max_int)

        return q_matrix, scales

    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
        if self.loaded:
            return
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weight(device=device)

        if isinstance(w, nn.Parameter):
            try:
                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
            except:
                weight = w.to(dtype=self.compute_dtype)
            self.has_bias = False
        elif isinstance(w, tuple):
            try:
                weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
            except:
                weight = w[0].to(dtype=self.compute_dtype)
            self.bias = w[1].to(dtype=self.compute_dtype).to(device)
            self.has_bias = True
        else:
            raise ValueError("Invalid weight type")

        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
        self.weight = self.weight.to(device)
        self.weight_scale = self.weight_scale.to(device)
        if self.has_bias:
            self.bias = self.bias.to(device)
        self.loaded = True

    def unload(self):
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None
        self._orig_weight = None
        if self.has_bias:
            self.bias = None
        self.loaded = False


class KLinearFP8(KLinearBase):
    # this kernel requires special handling for weight
    # Please load the weight file downloaded from KVCache.AI
...
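A quick standalone check of the per-column symmetric scheme KLinearQ8 implements: one fp32 scale per column, int8 storage, and a reconstruction error bounded by half a quantization step (a sketch, not part of the commit):

    import torch

    w = torch.randn(1024, 1024)
    max_abs = w.abs().amax(dim=0).clamp(min=1e-12)   # per-column max, guarding zero columns
    scales = max_abs / 127                           # Q8: map the column max to +/-127
    q = torch.round(w / scales).to(torch.int8)       # quantize, as _quantize_weight does
    w_hat = q.to(torch.float32) * scales             # dequantize, as forward() does
    assert (w - w_hat).abs().max() <= scales.max() * 0.5 + 1e-6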
@@ -468,6 +659,7 @@ LINEAR_MAP = {
     "KLinearTorch": KLinearTorch,
     "KLinearCPUInfer": KLinearCPUInfer,
     "KLinearFP8": KLinearFP8,
+    "KLinearQ8": KLinearQ8,
 }

 class KTransformersLinear(BaseInjectedModule, KLinearBase):
...
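With the LINEAR_MAP entry in place, the new backend becomes selectable from an injection rule. A hedged sketch of what such a rule could look like, expressed as the dict an optimize-rule YAML would deserialize to (field names follow ktransformers' optimize configs; treat them as assumptions rather than documented API):

    replace_rule = {
        "class": "ktransformers.operators.linear.KTransformersLinear",
        "kwargs": {
            "generate_device": "cuda",
            "prefill_device": "cuda",
            "generate_op": "KLinearQ8",     # new in this commit
            "prefill_op": "KLinearTorch",
        },
    }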
ktransformers/operators/models.py

@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
     DeepseekV2DecoderLayer,
     DeepseekV2MoE,
 )
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
...
@@ -649,8 +650,8 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability() < 8:
-                print("for Windows or GPU before ampere, use forward_windows")
+            if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+                # print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
                     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
...
@@ -673,6 +674,7 @@ class KDeepseekV2Model(BaseInjectedModule):
         t_f = 0

         for i, decoder_layer in enumerate(self.layers):
+            # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
             if self.transfer_map is not None and i in self.transfer_map:
                 prev_stream = torch.cuda.current_stream()
                 cur_device = self.transfer_map[i]
...