OpenDAS / ktransformers
Unverified commit 52fa671c, authored Mar 26, 2025 by Yuhao Tsui and committed by GitHub on Mar 26, 2025.
Merge branch 'kvcache-ai:main' into main
Parents: e5694f91, f142f4df
Changes: 52
Showing 20 changed files with 1072 additions and 43 deletions (+1072, -43)
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh  (+1, -1)
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh  (+5, -0)
ktransformers/ktransformers_ext/ext_bindings.cpp  (+0, -1)
ktransformers/ktransformers_ext/vendors/cuda.h  (+15, -0)
ktransformers/ktransformers_ext/vendors/hip.h  (+172, -0)
ktransformers/ktransformers_ext/vendors/musa.h  (+137, -0)
ktransformers/ktransformers_ext/vendors/vendor.h  (+13, -0)
ktransformers/local_chat.py  (+3, -2)
ktransformers/local_chat_test.py  (+171, -0)
ktransformers/operators/attention.py  (+9, -4)
ktransformers/operators/cpuinfer.py  (+6, -1)
ktransformers/operators/dynamic_attention.py  (+8, -2)
ktransformers/operators/experts.py  (+3, -1)
ktransformers/operators/gate.py  (+137, -17)
ktransformers/operators/linear.py  (+177, -7)
ktransformers/operators/models.py  (+4, -2)
ktransformers/operators/triton_attention.py  (+3, -3)
ktransformers/operators/triton_attention_prefill.py  (+206, -0)
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml  (+1, -1)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml  (+1, -1)
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh

@@ -39,7 +39,7 @@ using I4 = Vec<int, 4>;
 constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)
   // No support for async
 #else
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh

@@ -8,6 +8,11 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>

+#ifdef __HIP_PLATFORM_AMD__
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
+#endif
+
 namespace gptq_marlin {

 template <typename scalar_t>
ktransformers/ktransformers_ext/ext_bindings.cpp

@@ -9,7 +9,6 @@
 **/
 // Python bindings
 #include "cpu_backend/cpuinfer.h"
-#include "device_launch_parameters.h"
 #include "llamafile/flags.h"
 #include "operators/kvcache/kvcache.h"
 #include "operators/llamafile/linear.h"
ktransformers/ktransformers_ext/vendors/cuda.h (new file)
#pragma once
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
ktransformers/ktransformers_ext/vendors/hip.h (new file)
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaHostFn_t hipHostFn_t
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef hip_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/vendors/musa.h (new file)
#pragma once
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaSuccess musaSuccess
// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/vendors/vendor.h (new file)
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H
#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif
#endif // CPUINFER_VENDOR_VENDOR_H
\ No newline at end of file
ktransformers/local_chat.py

@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,

@@ -56,7 +57,7 @@ def local_chat(
     model_path: str | None = None,
     optimize_config_path: str = None,
     gguf_path: str | None = None,
-    max_new_tokens: int = 300,
+    max_new_tokens: int = 1000,
     cpu_infer: int = Config().cpu_infer,
     use_cuda_graph: bool = True,
     prompt_file: str | None = None,

@@ -169,7 +170,7 @@ def local_chat(
             assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
             "please change max_seq_len in ~/.ktransformers/config.yaml"
-        if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+        if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
             generated = prefill_and_generate(
                 model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think, chunk_prefill_size=chunk_prefill_size,
                 use_flashinfer_mla=True, num_heads=config.num_attention_heads, head_dim_ckv=config.kv_lora_rank, head_dim_kpe=config.qk_rope_head_dim, q_head_dim=config.qk_rope_head_dim + config.qk_nope_head_dim
ktransformers/local_chat_test.py (new file)

"""
Description  :
Author       : Boxin Zhang, Azure-Tang
Version      : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

import os
import platform
import sys

project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import logging
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
}

ktransformer_rules_dir = (
    os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
    "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
    "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
    "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
}


def local_chat(
    model_path: str | None = None,
    optimize_config_path: str = None,
    gguf_path: str | None = None,
    max_new_tokens: int = 1000,
    cpu_infer: int = Config().cpu_infer,
    use_cuda_graph: bool = True,
    prompt_file: str | None = None,
    mode: str = "normal",
    force_think: bool = False,
    chunk_prefill_size: int = 8192
):
    torch.set_grad_enabled(False)

    Config().cpu_infer = cpu_infer

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if mode == 'long_context':
        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
        torch.set_default_dtype(torch.float16)
    else:
        torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if ("Qwen2Moe" in config.architectures[0]):
                # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            if "Llama" in config.architectures[0]:
                config._attn_implementation = "eager"
            if "Mixtral" in config.architectures[0]:
                config._attn_implementation = "flash_attention_2"

            model = custom_models[config.architectures[0]](config)
        else:
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation="flash_attention_2"
            )

    if optimize_config_path is None:
        if config.architectures[0] in default_optimize_rules:
            print("using default_optimize_rule for", config.architectures[0])
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = input(
                "please input the path of your rule file(yaml file containing optimize rules):"
            )

    if gguf_path is None:
        gguf_path = input(
            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
        )
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)

    try:
        model.generation_config = GenerationConfig.from_pretrained(model_path)
    except Exception as e:
        print(f"generation config can't auto create, make default. Message: {e}")
        gen_config = GenerationConfig(
            temperature=0.6,
            top_p=0.95,
            do_sample=True
        )
        model.generation_config = gen_config
    # model.generation_config = GenerationConfig.from_pretrained(model_path)
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.eval()
    logging.basicConfig(level=logging.INFO)

    system = platform.system()
    if system == "Windows":
        os.system("cls")
    else:
        os.system("clear")

    if prompt_file != None:
        assert os.path.isfile(prompt_file), "prompt file not exist"
        print(f"prompt file is {prompt_file}")
        content = open(prompt_file, "r").read()
    else:
        content = "Please write a piece of quicksort code in C++."
    print('Start Testing...(1 round)')
    print('Prompt:', content)
    while True:
        messages = [{"role": "user", "content": content}]
        input_tensor = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        if force_think:
            token_thinks = torch.tensor([tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_tensor.device)
            input_tensor = torch.cat(
                [input_tensor, token_thinks], dim=1
            )
        if mode == 'long_context':
            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
            "please change max_seq_len in ~/.ktransformers/config.yaml"

        if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
            generated = prefill_and_generate(
                model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think, chunk_prefill_size=chunk_prefill_size,
                use_flashinfer_mla=True, num_heads=config.num_attention_heads, head_dim_ckv=config.kv_lora_rank, head_dim_kpe=config.qk_rope_head_dim, q_head_dim=config.qk_rope_head_dim + config.qk_nope_head_dim
            )
        else:
            generated = prefill_and_generate(
                model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think, chunk_prefill_size=chunk_prefill_size,
            )
        break


if __name__ == "__main__":
    fire.Fire(local_chat)
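Since local_chat_test.py only wraps local_chat with fire.Fire, it can also be driven directly from Python. A minimal sketch of such a call follows; the model and GGUF paths are placeholders chosen for illustration, not values taken from the commit.

# Hypothetical direct invocation of the test entry point; paths below are placeholders.
from ktransformers.local_chat_test import local_chat

local_chat(
    model_path="/models/DeepSeek-V2-Lite-Chat",       # assumed local HF config/tokenizer dir
    gguf_path="/models/DeepSeek-V2-Lite-Chat-GGUF",   # assumed dir holding the matching GGUF files
    max_new_tokens=256,
    prompt_file=None,   # falls back to the built-in quicksort prompt
)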
ktransformers/operators/attention.py

@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+try:
+    from flash_attn import flash_attn_func
+except:
+    pass
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:

@@ -589,8 +595,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability() < 8:
-            print("for Windows or GPU before ampere, use forward_windows")
+        if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
ktransformers/operators/cpuinfer.py

@@ -727,7 +727,12 @@ class CPUInferKVCache:

 class CPUInfer:
     cpuinfer = None
+    cur_backend_thread_num = 0

     def __init__(self, thread_num):
-        CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)
+        if thread_num > CPUInfer.cur_backend_thread_num:
+            CPUInfer.cur_backend_thread_num = thread_num
+            del CPUInfer.cpuinfer
+            CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)

     def submit(self, task):
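With this change, CPUInfer keeps one class-level cpuinfer_ext.CPUInfer backend and only rebuilds it when a caller asks for more threads than the current backend was created with. A self-contained sketch of that grow-only shared-backend pattern is shown below; _Backend stands in for the compiled cpuinfer_ext.CPUInfer, and the class names are illustrative, not from the commit.

# Sketch of the grow-only shared-backend pattern used by CPUInfer above.
class _Backend:
    def __init__(self, thread_num: int):
        self.thread_num = thread_num  # stand-in for cpuinfer_ext.CPUInfer(thread_num)

class SharedPool:
    backend = None          # shared by every instance of the class
    cur_thread_num = 0

    def __init__(self, thread_num: int):
        # Rebuild the shared backend only when more threads are requested
        # than the existing backend was created with.
        if thread_num > SharedPool.cur_thread_num:
            SharedPool.cur_thread_num = thread_num
            SharedPool.backend = _Backend(thread_num)

a = SharedPool(8)
b = SharedPool(4)   # reuses the existing 8-thread backend
assert a.backend is b.backend and SharedPool.cur_thread_num == 8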
ktransformers/operators/dynamic_attention.py

@@ -17,7 +17,10 @@ import logging
 logger = logging.getLogger("dynamic_attention")
 sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
 from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+try:
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
+except:
+    print("falsh attn not found")
 import math

@@ -26,6 +29,7 @@ import json
 class DynamicScaledDotProductAttention:
     remaining_length: int
+    cpu_infer = None
     def __init__(
         self,

@@ -180,7 +184,9 @@ class DynamicScaledDotProductAttention:
         self.preselect_block_num = 0  # block_num before preselect
         self.evict_tokens = 0
-        self.cpu_infer = CPUInfer(threads_num)
+        if DynamicScaledDotProductAttention.cpu_infer is None:
+            DynamicScaledDotProductAttention.cpu_infer = CPUInfer(threads_num)
+        self.cpu_infer = DynamicScaledDotProductAttention.cpu_infer
         self.local_thread = CPUInferKVCache(
             self.layer_num,
             self.kv_head_num,
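Both this file and attention.py now wrap the flash_attn import in try/except so the module still imports on machines where flash-attn is not installed (for example the AMD/MUSA builds this merge targets). A minimal sketch of that optional-dependency guard follows; the flash_attn_available flag is an illustrative name, not an identifier from the commit.

# Sketch of the optional-import guard used above; `flash_attn_available` is illustrative.
try:
    from flash_attn import flash_attn_func, flash_attn_with_kvcache
    flash_attn_available = True
except ImportError:
    flash_attn_available = False

def fused_attention(*args, **kwargs):
    # Call the fast kernel when present, otherwise signal that a fallback path is needed.
    if flash_attn_available:
        return flash_attn_func(*args, **kwargs)
    raise RuntimeError("flash-attn not installed; use the fallback attention path")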
ktransformers/operators/experts.py

@@ -120,7 +120,7 @@ class KExpertsCPU(KExpertsBase):
     output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
     #stream_map:dict = {} # Manage cuda stream on different gpu
     #gguf_loader:GGUFLoader = None
-    CPU_INFER = CPUInfer(Config().cpu_infer)
+    CPU_INFER = None
     def __init__(
         self,
         key: str,

@@ -133,6 +133,8 @@ class KExpertsCPU(KExpertsBase):
         **kwargs
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        if KExpertsCPU.CPU_INFER is None:
+            KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer)
         #if KExpertsCPU.gguf_loader is None:
         #    KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
         self.gguf_loader = gguf_loader
ktransformers/operators/gate.py

+from typing import Optional
-from typing import Any, Union
+from torch import nn
-import numpy as np
-import numpy.typing as npt
-from torch import Tensor, nn
-import torch.nn.functional as F
 import torch
-import sys, os
+import torch.nn.functional as F
+import os
 from ktransformers.operators.base_operator import BaseInjectedModule
-sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
-sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
-sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
-import cpuinfer_ext
-from cpuinfer_ext.moe import MOEConfig, MOE
-import ctypes
 from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.operators.linear import KTransformersLinear
 from ktransformers.util.custom_gguf import GGUFLoader
-from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from abc import ABC, abstractmethod
-import time

 # class Base(BaseInjectedModule, ABC):

@@ -100,18 +89,147 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def forward(self, hidden_states) -> torch.Tensor:
        return self.orig_module.forward(hidden_states)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weights(device=device)

        if isinstance(w, dict):
            self.weight_type = w["weight_type"]
            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
            self.orig_module.weight = nn.Parameter(w["weight"])
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None

# adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071
# This is used by the Deepseek-V2 and Deepseek-V3 model
#@torch.compile(dynamic=True)
def grouped_topk(hidden_states: torch.Tensor,
                 gating_output: torch.Tensor,
                 topk: int,
                 renormalize: bool,
                 num_expert_group: int = 0,
                 topk_group: int = 0,
                 routed_scaling_factor: float = 1.0,
                 scoring_func: str = "sigmoid",
                 e_score_correction_bias: Optional[torch.Tensor] = None):
    assert hidden_states.shape[0] == gating_output.shape[0], ("Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1))
    else:
        group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if topk > 1 and renormalize:
        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        topk_weights = topk_weights / denominator
    topk_weights = topk_weights * routed_scaling_factor  # must multiply the scaling factor

    return topk_ids.to(torch.long), topk_weights.to(torch.float32)

class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        generate_op: str | None = "KLinearMarlin",
        prefill_device: str = "cuda",
        prefill_op: str | None = "KLinearMarlin",
        use_quant: bool = False,
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device
        self.generate_op = generate_op
        self.prefill_op = prefill_op
        self.is_windows = os.name == 'nt'
        self.use_quant = use_quant
        if not self.is_windows and use_quant:
            print("injecting gate_linear")
            self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
            self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                   gguf_loader, config, self.gate_linear, #orig_module
                                                   generate_device, generate_op, prefill_device, prefill_op)
        else:
            self.gate_linear = None

    def forward(self, hidden_states) -> torch.Tensor:
        if True or self.is_windows:
            return self.orig_module.forward(hidden_states)

        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        if self.use_quant:
            logits = self.gate_linear.forward(hidden_states)
        else:
            logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)

        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
                            self.n_group, self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weights(device=device)

@@ -125,6 +243,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
             raise ValueError("Invalid weight type")
         self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
         self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
+        if not self.is_windows and self.use_quant:
+            self.gate_linear.load(self.orig_module.weight)

     def unload(self):
         if self.weight is not None:
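The grouped_topk routine added above first scores every expert per token, keeps only the topk_group best expert groups, masks the rest out, then selects the final top-k experts and rescales their weights. A toy invocation with made-up sizes (2 tokens, 8 experts in 4 groups) illustrates the expected shapes; the numbers are for illustration only.

# Toy call to grouped_topk: 2 tokens, 8 experts in 4 groups, 2 groups kept, 2 experts routed.
import torch

num_tokens, num_experts = 2, 8
hidden = torch.randn(num_tokens, 16)            # hidden size chosen arbitrarily
gating = torch.randn(num_tokens, num_experts)   # raw gate logits

topk_ids, topk_weights = grouped_topk(
    hidden, gating,
    topk=2, renormalize=True,
    num_expert_group=4, topk_group=2,
    routed_scaling_factor=1.0, scoring_func="sigmoid",
)
print(topk_ids.shape, topk_weights.shape)   # torch.Size([2, 2]) torch.Size([2, 2])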
ktransformers/operators/linear.py

@@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 import cpuinfer_ext
 from ktransformers.operators.cpuinfer import CPUInfer
 from ktransformers.server.config.config import Config
+from typing import Dict, Tuple, Optional, Union
+import numpy as np

 #class KLinearBase(BaseInjectedModule, ABC):
 class KLinearBase(ABC):

@@ -176,16 +178,182 @@ class KLinearTorch(KLinearBase):
         if self.has_bias:
             self.bias = None

+class KLinearQ8(KLinearBase):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "cuda",
+        **kwargs,
+    ):
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.has_bias = False
+        self.compute_dtype = torch.float32
+        self.weight = None
+        self.weight_scale = None
+        self.weight_zero_point = None
+        self.bias = None
+        self.loaded = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_dtype = x.dtype
+        out_device = x.device
+        x = x.to(device=self.device, dtype=self.compute_dtype)
+        # Multiply by the (reconstructed) original weight to mimic the original behavior
+        # Dequantize the weight for the matrix multiplication
+        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
+        out = x @ weight_dequant.T
+
+        if self.has_bias:
+            out = out + self.bias
+
+        return out.to(dtype=orig_dtype, device=out_device)
+
+    def _dequantize_weight(self, q_matrix, scales, bits=8):
+        """
+        Dequantize a low-precision matrix back to floating-point
+
+        Args:
+            q_matrix (torch.Tensor): Quantized int matrix
+            scales (torch.Tensor): Scale factors for each column
+            bits (int): Quantization bits used (8 or 4)
+
+        Returns:
+            torch.Tensor: Dequantized floating-point matrix
+        """
+        # Ensure inputs are torch tensors
+        if not isinstance(q_matrix, torch.Tensor):
+            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
+        if not isinstance(scales, torch.Tensor):
+            scales = torch.tensor(scales, dtype=torch.float32)
+
+        # Convert to correct dtype if needed
+        if q_matrix.dtype != torch.int8:
+            q_matrix = q_matrix.to(torch.int8)
+        if scales.dtype != torch.float32:
+            scales = scales.to(torch.float32)
+
+        # For Q4, ensure the values stay within 4-bit range
+        if bits == 4:
+            q_matrix = torch.clamp(q_matrix, -7, 7)
+
+        rows, cols = q_matrix.shape
+        dequant_matrix = q_matrix.to(torch.float32)
+        scales_broadcast = scales.view(1, cols)
+
+        # Apply dequantization to all columns at once using matrix multiplication
+        dequant_matrix = dequant_matrix * scales_broadcast
+
+        return dequant_matrix
+
+    def _quantize_weight(self, matrix, bits=8):
+        """
+        Quantize a floating-point matrix to lower precision (Q8 or Q4)
+
+        Args:
+            matrix (torch.Tensor): Input matrix in floating-point format
+            bits (int): Quantization bits, either 8 or 4
+
+        Returns:
+            tuple: (quantized int matrix, scale factors for each column)
+        """
+        if not isinstance(matrix, torch.Tensor):
+            matrix = torch.tensor(matrix, dtype=torch.float32)
+
+        # Convert to float32 if needed
+        if matrix.dtype != torch.float32:
+            matrix = matrix.to(torch.float32)
+
+        # Get matrix shape
+        rows, cols = matrix.shape
+
+        # Determine quantization parameters based on bits
+        if bits == 8:
+            max_int = 127
+            qtype = torch.int8
+        elif bits == 4:
+            max_int = 7
+            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range, wait for native support
+        else:
+            raise ValueError("Quantization bits must be either 8 or 4")
+
+        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
+
+        # Calculate max absolute value for each column
+        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)
+
+        # Handle zero columns (avoid division by zero)
+        zero_cols = max_abs_vals == 0
+        max_abs_vals[zero_cols] = 1.0
+
+        # Calculate scale factors for all columns at once
+        scales = max_abs_vals / max_int
+
+        # Prepare the scales for broadcasting [1, cols]
+        scales_broadcast = scales.view(1, cols)
+
+        # Apply quantization to the entire matrix at once
+        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)
+
+        # For Q4, clamp values to ensure they stay within 4-bit range
+        if bits == 4:
+            q_matrix = torch.clamp(q_matrix, -max_int, max_int)
+
+        return q_matrix, scales
+
+    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
+        if self.loaded: return
+        if device is None: device = self.device
+        if w is None: w = self.load_weight(device=device)
+
+        if isinstance(w, nn.Parameter):
+            try:
+                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
+            except:
+                weight = w.to(dtype=self.compute_dtype)
+            self.has_bias = False
+        elif isinstance(w, tuple):
+            try:
+                weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
+            except:
+                weight = w[0].to(dtype=self.compute_dtype)
+            self.bias = w[1].to(dtype=self.compute_dtype).to(device)
+            self.has_bias = True
+        else:
+            raise ValueError("Invalid weight type")
+
+        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
+        self.weight = self.weight.to(device)
+        self.weight_scale = self.weight_scale.to(device)
+        if self.has_bias:
+            self.bias = self.bias.to(device)
+
+        self.loaded = True
+
+    def unload(self):
+        self.weight = None
+        self.weight_scale = None
+        self.weight_zero_point = None
+        self._orig_weight = None
+        if self.has_bias:
+            self.bias = None
+        self.loaded = False
+
 class KLinearFP8(KLinearBase):
     # this kernel requires special handling for weight
     # Please load the weight file downloaded from KVCache.AI
-    marlin_q_w: torch.Tensor
-    marlin_s: torch.Tensor
-    g_idx: torch.Tensor
-    sort_indices: torch.Tensor
     has_bias: bool
     weight: torch.Tensor
+    scale_w: torch.Tensor
     bias: torch.Tensor
     def __init__(
         self,

@@ -360,7 +528,7 @@ class KLinearMarlin(KLinearBase):
         self.workspace = None

 class KLinearCPUInfer(KLinearBase):
-    CPU_INFER = CPUInfer(Config().cpu_infer)
+    CPU_INFER = None
     def __init__(
         self,
         key: str,

@@ -374,6 +542,8 @@ class KLinearCPUInfer(KLinearBase):
         **kwargs,
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        if KLinearCPUInfer.CPU_INFER is None:
+            KLinearCPUInfer.CPU_INFER = CPUInfer(Config().cpu_infer)
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
         self.w = None

@@ -466,6 +636,7 @@ LINEAR_MAP = {
     "KLinearTorch": KLinearTorch,
     "KLinearCPUInfer": KLinearCPUInfer,
     "KLinearFP8": KLinearFP8,
+    "KLinearQ8": KLinearQ8,
 }

 class KTransformersLinear(BaseInjectedModule, KLinearBase):

@@ -475,7 +646,6 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
     gguf_loader: GGUFLoader,
     config: PretrainedConfig,
     orig_module: nn.Module,
-    # device: str = "cuda",
     generate_device: str = "cuda",
     generate_op: str | None = "KLinearMarlin",
     prefill_device: str = "cuda",
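KLinearQ8 stores int8 weights with one float32 scale per column (scale = max absolute value / 127) and dequantizes by multiplying the int8 matrix by the broadcast scales. The following stand-alone sketch shows that round trip outside the class, as an illustration of the scheme rather than a drop-in replacement for it.

# Stand-alone sketch of the per-column symmetric int8 scheme KLinearQ8 uses.
import torch

def quantize_q8(matrix: torch.Tensor):
    # One scale per column: scale = max(|column|) / 127, guarding all-zero columns.
    max_abs = matrix.abs().amax(dim=0)
    max_abs[max_abs == 0] = 1.0
    scales = max_abs / 127
    q = torch.round(matrix / scales.view(1, -1)).to(torch.int8)
    return q, scales

def dequantize_q8(q: torch.Tensor, scales: torch.Tensor):
    return q.to(torch.float32) * scales.view(1, -1)

w = torch.randn(64, 32)
q, s = quantize_q8(w)
w_hat = dequantize_q8(q, s)
print((w - w_hat).abs().max())   # reconstruction error stays on the order of one scale step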
ktransformers/operators/models.py

@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
     DeepseekV2DecoderLayer,
     DeepseekV2MoE,
 )
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule

@@ -649,8 +650,8 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability() < 8:
-                print("for Windows or GPU before ampere, use forward_windows")
+            if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+                # print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
                     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions

@@ -673,6 +674,7 @@ class KDeepseekV2Model(BaseInjectedModule):
             t_f = 0

         for i, decoder_layer in enumerate(self.layers):
+            # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
             if self.transfer_map is not None and i in self.transfer_map:
                 prev_stream = torch.cuda.current_stream()
                 cur_device = self.transfer_map[i]
ktransformers/operators/triton_attention.py

@@ -6,7 +6,7 @@
 import triton
 import triton.language as tl
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid

@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
     # [TODO] work around shmem limit on MI3xx
     # TODO: support hip
-    # if is_hip_ and Lk >= 576:
-    #     BLOCK = 16
+    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
+        BLOCK = 16

     if Lk == 576:
         BLOCK_DMODEL = 512
ktransformers/operators/triton_attention_prefill.py (new file)

# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
"""
Memory-efficient attention for prefill.
It supporst page size = 1.
"""

# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    end_n = (
        cur_batch_seq_len
        if not IS_CAUSAL
        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    )
    for start_n in range(0, block_mask * end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(
                (start_n + offs_n[None, :] < cur_batch_seq_len)
                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
                0,
                float("-inf"),
            )
        else:
            qk += tl.where(
                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
            )

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]))


def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=BLOCK,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
    )
\ No newline at end of file
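Per its docstring, context_attention_fwd takes the whole batch flattened into one packed token dimension: q, k, v, and o are [total_tokens, heads, head_dim], b_seq_len holds each sequence's length, and b_start_loc appears to hold each sequence's start offset in the packed dimension. The sketch below only illustrates that calling convention with made-up sizes; it assumes a CUDA GPU with Triton installed.

# Illustrative call shapes for context_attention_fwd (needs a CUDA GPU and Triton).
import torch

seq_lens = torch.tensor([5, 3], dtype=torch.int32, device="cuda")
start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")   # prefix offsets into the packed dim
total, heads, head_dim = int(seq_lens.sum()), 8, 64

q = torch.randn(total, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

context_attention_fwd(q, k, v, o, start_loc, seq_lens,
                      max_input_len=int(seq_lens.max()), is_causal=True)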
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml

@@ -22,7 +22,7 @@
   replace:
     class: ktransformers.operators.linear.KTransformersLinear
     kwargs:
-      generate_device: "cuda"
+      generate_device: "cpu"
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml

@@ -26,7 +26,7 @@
 - match:
     class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
-    class: ktransformers.operators.gate.KMoEGate
+    class: ktransformers.operators.gate.KMoEGateDeepSeekV3
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"