OpenDAS / ktransformers · Commits · 52fa671c

Unverified commit 52fa671c, authored Mar 26, 2025 by Yuhao Tsui, committed via GitHub on Mar 26, 2025.

    Merge branch 'kvcache-ai:main' into main

Parents: e5694f91, f142f4df
Changes: 52 files in total; this page shows 20 changed files with 1072 additions and 43 deletions (+1072 / -43).
Changed files shown on this page:
- ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh (+1, -1)
- ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh (+5, -0)
- ktransformers/ktransformers_ext/ext_bindings.cpp (+0, -1)
- ktransformers/ktransformers_ext/vendors/cuda.h (+15, -0)
- ktransformers/ktransformers_ext/vendors/hip.h (+172, -0)
- ktransformers/ktransformers_ext/vendors/musa.h (+137, -0)
- ktransformers/ktransformers_ext/vendors/vendor.h (+13, -0)
- ktransformers/local_chat.py (+3, -2)
- ktransformers/local_chat_test.py (+171, -0)
- ktransformers/operators/attention.py (+9, -4)
- ktransformers/operators/cpuinfer.py (+6, -1)
- ktransformers/operators/dynamic_attention.py (+8, -2)
- ktransformers/operators/experts.py (+3, -1)
- ktransformers/operators/gate.py (+137, -17)
- ktransformers/operators/linear.py (+177, -7)
- ktransformers/operators/models.py (+4, -2)
- ktransformers/operators/triton_attention.py (+3, -3)
- ktransformers/operators/triton_attention_prefill.py (+206, -0)
- ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml (+1, -1)
- ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml (+1, -1)
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh  (view file @ 52fa671c)
 ...
@@ -39,7 +39,7 @@ using I4 = Vec<int, 4>;
 constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)
   // No support for async
 #else
 ...
ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh  (view file @ 52fa671c)
 ...
@@ -8,6 +8,11 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
+#ifdef __HIP_PLATFORM_AMD__
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
+#endif

 namespace gptq_marlin {

 template <typename scalar_t>
 ...
ktransformers/ktransformers_ext/ext_bindings.cpp  (view file @ 52fa671c)
 ...
@@ -9,7 +9,6 @@
 **/
 // Python bindings
 #include "cpu_backend/cpuinfer.h"
-#include "device_launch_parameters.h"
 #include "llamafile/flags.h"
 #include "operators/kvcache/kvcache.h"
 #include "operators/llamafile/linear.h"
 ...
ktransformers/ktransformers_ext/vendors/cuda.h  (new file, 0 → 100644, view file @ 52fa671c)
#pragma once
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
ktransformers/ktransformers_ext/vendors/hip.h  (new file, 0 → 100644, view file @ 52fa671c)
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaHostFn_t hipHostFn_t
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef hip_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/vendors/musa.h  (new file, 0 → 100644, view file @ 52fa671c)
#pragma once
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaSuccess musaSuccess
// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/vendors/vendor.h  (new file, 0 → 100644, view file @ 52fa671c)
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H
#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif
#endif // CPUINFER_VENDOR_VENDOR_H
\ No newline at end of file
ktransformers/local_chat.py  (view file @ 52fa671c)
 ...
@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
 ...
@@ -56,7 +57,7 @@ def local_chat(
     model_path: str | None = None,
     optimize_config_path: str = None,
     gguf_path: str | None = None,
-    max_new_tokens: int = 300,
+    max_new_tokens: int = 1000,
     cpu_infer: int = Config().cpu_infer,
     use_cuda_graph: bool = True,
     prompt_file: str | None = None,
 ...
@@ -169,7 +170,7 @@ def local_chat(
         assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
         "please change max_seq_len in ~/.ktransformers/config.yaml"
-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
 ...
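For reference, a minimal sketch of calling the updated entry point with the new 1000-token default. The model and GGUF paths below are placeholders, not part of this commit, and a built ktransformers install is assumed.

# Hypothetical invocation of ktransformers.local_chat.local_chat (paths are placeholders).
from ktransformers.local_chat import local_chat

local_chat(
    model_path="/models/DeepSeek-V2-Lite-Chat",       # placeholder path
    gguf_path="/models/DeepSeek-V2-Lite-Chat-GGUF",   # placeholder path
    max_new_tokens=1000,   # new default introduced by this merge (was 300)
    use_cuda_graph=True,
)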
ktransformers/local_chat_test.py  (new file, 0 → 100644, view file @ 52fa671c)

"""
Description  :
Author       : Boxin Zhang, Azure-Tang
Version      : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

import os
import platform
import sys

project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import logging
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    TextStreamer,
)
import json
import fire
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
    "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
}

ktransformer_rules_dir = (
    os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/"
)
default_optimize_rules = {
    "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml",
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
    "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
    "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
}


def local_chat(
    model_path: str | None = None,
    optimize_config_path: str = None,
    gguf_path: str | None = None,
    max_new_tokens: int = 1000,
    cpu_infer: int = Config().cpu_infer,
    use_cuda_graph: bool = True,
    prompt_file: str | None = None,
    mode: str = "normal",
    force_think: bool = False,
    chunk_prefill_size: int = 8192
):
    torch.set_grad_enabled(False)

    Config().cpu_infer = cpu_infer

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if mode == 'long_context':
        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
        torch.set_default_dtype(torch.float16)
    else:
        torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if (
                "Qwen2Moe" in config.architectures[0]
            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            if "Llama" in config.architectures[0]:
                config._attn_implementation = "eager"
            if "Mixtral" in config.architectures[0]:
                config._attn_implementation = "flash_attention_2"

            model = custom_models[config.architectures[0]](config)
        else:
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation="flash_attention_2"
            )

    if optimize_config_path is None:
        if config.architectures[0] in default_optimize_rules:
            print("using default_optimize_rule for", config.architectures[0])
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = input(
                "please input the path of your rule file(yaml file containing optimize rules):"
            )

    if gguf_path is None:
        gguf_path = input(
            "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
        )
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)

    try:
        model.generation_config = GenerationConfig.from_pretrained(model_path)
    except Exception as e:
        print(f"generation config can't auto create, make default. Message: {e}")
        gen_config = GenerationConfig(
            temperature=0.6,
            top_p=0.95,
            do_sample=True
        )
        model.generation_config = gen_config
    # model.generation_config = GenerationConfig.from_pretrained(model_path)
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.eval()
    logging.basicConfig(level=logging.INFO)

    system = platform.system()
    if system == "Windows":
        os.system("cls")
    else:
        os.system("clear")

    if prompt_file != None:
        assert os.path.isfile(prompt_file), "prompt file not exist"
        print(f"prompt file is {prompt_file}")
        content = open(prompt_file, "r").read()
    else:
        content = "Please write a piece of quicksort code in C++."
    print('Start Testing...(1 round)')
    print('Prompt:', content)
    while True:
        messages = [{"role": "user", "content": content}]
        input_tensor = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        if force_think:
            token_thinks = torch.tensor([tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_tensor.device)
            input_tensor = torch.cat(
                [input_tensor, token_thinks], dim=1
            )
        if mode == 'long_context':
            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
            "please change max_seq_len in ~/.ktransformers/config.yaml"

        if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
            generated = prefill_and_generate(
                model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
                use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
            )
        else:
            generated = prefill_and_generate(
                model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
            )
        break


if __name__ == "__main__":
    fire.Fire(local_chat)
ktransformers/operators/attention.py  (view file @ 52fa671c)
 ...
@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+try:
+    from flash_attn import flash_attn_func
+except:
+    pass
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
 ...
@@ -589,8 +595,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability() < 8:
-            print("for Windows or GPU before ampere, use forward_windows")
+        if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
 ...
ktransformers/operators/cpuinfer.py  (view file @ 52fa671c)
 ...
@@ -727,7 +727,12 @@ class CPUInferKVCache:

class CPUInfer:
    cpuinfer = None
    cur_backend_thread_num = 0
    def __init__(self, thread_num):
        if thread_num > CPUInfer.cur_backend_thread_num:
            CPUInfer.cur_backend_thread_num = thread_num
            del CPUInfer.cpuinfer
            CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)

    def submit(self, task):
 ...
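This hunk, together with the matching changes in experts.py and linear.py below, moves the CPU backend to a lazily created, class-level shared object that is only rebuilt when a larger thread count is requested. A minimal standalone sketch of that pattern follows; Backend is a stand-in class, not the real cpuinfer_ext.CPUInfer extension.

# Sketch of the shared-backend pattern adopted above (Backend is a placeholder).
class Backend:
    def __init__(self, thread_num):
        self.thread_num = thread_num   # pretend this allocates a thread pool

class SharedInfer:
    backend = None          # one backend shared by all instances
    cur_thread_num = 0      # largest thread count requested so far

    def __init__(self, thread_num):
        if thread_num > SharedInfer.cur_thread_num:
            # replace the shared backend only when more threads are needed
            SharedInfer.cur_thread_num = thread_num
            SharedInfer.backend = Backend(thread_num)

a = SharedInfer(4)
b = SharedInfer(2)          # reuses the existing 4-thread backend
assert SharedInfer.backend.thread_num == 4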
ktransformers/operators/dynamic_attention.py  (view file @ 52fa671c)
 ...
@@ -17,7 +17,10 @@ import logging
 logger = logging.getLogger("dynamic_attention")
 sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
 from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+try:
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
+except:
+    print("falsh attn not found")
 import math
 ...
@@ -26,6 +29,7 @@ import json
 class DynamicScaledDotProductAttention:
     remaining_length: int
+    cpu_infer = None

     def __init__(
         self,
 ...
@@ -180,7 +184,9 @@ class DynamicScaledDotProductAttention:
         self.preselect_block_num = 0  # block_num before preselect
         self.evict_tokens = 0
-        self.cpu_infer = CPUInfer(threads_num)
+        if DynamicScaledDotProductAttention.cpu_infer is None:
+            DynamicScaledDotProductAttention.cpu_infer = CPUInfer(threads_num)
+        self.cpu_infer = DynamicScaledDotProductAttention.cpu_infer
         self.local_thread = CPUInferKVCache(
             self.layer_num,
             self.kv_head_num,
 ...
ktransformers/operators/experts.py  (view file @ 52fa671c)
 ...
@@ -120,7 +120,7 @@ class KExpertsCPU(KExpertsBase):
     output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
     #stream_map:dict = {} # Manage cuda stream on different gpu
     #gguf_loader:GGUFLoader = None
-    CPU_INFER = CPUInfer(Config().cpu_infer)
+    CPU_INFER = None
     def __init__(
         self,
         key: str,
 ...
@@ -133,6 +133,8 @@ class KExpertsCPU(KExpertsBase):
         **kwargs
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        if KExpertsCPU.CPU_INFER is None:
+            KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer)
         #if KExpertsCPU.gguf_loader is None:
         #    KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
         self.gguf_loader = gguf_loader
 ...
ktransformers/operators/gate.py  (view file @ 52fa671c)

The import header of the file as shown in the diff (individual added/removed lines are not distinguishable in this view):

from typing import Any, Union
import numpy as np
import numpy.typing as npt
from torch import Tensor, nn
import torch.nn.functional as F
from typing import Optional
from torch import nn
import torch
import sys, os
import torch.nn.functional as F
import os
from ktransformers.operators.base_operator import BaseInjectedModule

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.operators.linear import KTransformersLinear
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
import time

# class Base(BaseInjectedModule, ABC):
 ...
@@ -100,18 +89,147 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def forward(self, hidden_states) -> torch.Tensor:
        return self.orig_module.forward(hidden_states)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None: device = self.device
        if w is None: w = self.load_weights(device=device)

        if isinstance(w, dict):
            self.weight_type = w["weight_type"]
            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
            self.orig_module.weight = nn.Parameter(w["weight"])
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None


# adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071
# This is used by the Deepseek-V2 and Deepseek-V3 model
#@torch.compile(dynamic=True)
def grouped_topk(hidden_states: torch.Tensor,
                 gating_output: torch.Tensor,
                 topk: int,
                 renormalize: bool,
                 num_expert_group: int = 0,
                 topk_group: int = 0,
                 routed_scaling_factor: float = 1.0,
                 scoring_func: str = "sigmoid",
                 e_score_correction_bias: Optional[torch.Tensor] = None):
    assert hidden_states.shape[0] == gating_output.shape[0], ("Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1))
    else:
        group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group, scores.shape[-1] // num_expert_group
    ).reshape(num_token, -1)  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  #float("-inf")) # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if topk > 1 and renormalize:
        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        topk_weights = topk_weights / denominator
    topk_weights = topk_weights * routed_scaling_factor  # must multiply the scaling factor

    return topk_ids.to(torch.long), topk_weights.to(torch.float32)
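A small self-contained sketch of what grouped_topk computes, using random router logits; the tensor sizes are illustrative and not taken from the commit, and a built ktransformers package is assumed for the import.

# Illustrative call of gate.grouped_topk on random data (shapes are made up).
import torch
from ktransformers.operators.gate import grouped_topk

num_tokens, num_experts, n_group = 4, 16, 4       # 16 experts split into 4 groups
hidden = torch.randn(num_tokens, 32)              # only its first dim is checked
logits = torch.randn(num_tokens, num_experts)     # router logits
bias = torch.zeros(num_experts)                   # e_score_correction_bias

topk_ids, topk_weights = grouped_topk(
    hidden, logits, topk=2, renormalize=True,
    num_expert_group=n_group, topk_group=2,
    routed_scaling_factor=1.0, scoring_func="sigmoid",
    e_score_correction_bias=bias,
)
print(topk_ids.shape, topk_weights.shape)   # torch.Size([4, 2]) torch.Size([4, 2])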
class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        generate_op: str | None = "KLinearMarlin",
        prefill_device: str = "cuda",
        prefill_op: str | None = "KLinearMarlin",
        use_quant: bool = False,
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device
        self.generate_op = generate_op
        self.prefill_op = prefill_op
        self.is_windows = os.name == 'nt'
        self.use_quant = use_quant
        if not self.is_windows and use_quant:
            print("injecting gate_linear")
            self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
            self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                   gguf_loader, config, self.gate_linear,  #orig_module
                                                   generate_device, generate_op, prefill_device, prefill_op)
        else:
            self.gate_linear = None

    def forward(self, hidden_states) -> torch.Tensor:
        if True or self.is_windows:
            return self.orig_module.forward(hidden_states)

        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        if self.use_quant:
            logits = self.gate_linear.forward(hidden_states)
        else:
            logits = F.linear(
                hidden_states.type(torch.float32), self.weight.type(torch.float32), None
            )

        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
                            self.n_group, self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None: device = self.device
        if w is None: w = self.load_weights(device=device)
 ...
@@ -125,6 +243,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
             raise ValueError("Invalid weight type")
         self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
         self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
+        if not self.is_windows and self.use_quant:
+            self.gate_linear.load(self.orig_module.weight)

     def unload(self):
         if self.weight is not None:
 ...
ktransformers/operators/linear.py  (view file @ 52fa671c)
 ...
@@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 import cpuinfer_ext
 from ktransformers.operators.cpuinfer import CPUInfer
 from ktransformers.server.config.config import Config
+from typing import Dict, Tuple, Optional, Union
+import numpy as np

 #class KLinearBase(BaseInjectedModule, ABC):
 class KLinearBase(ABC):
 ...
@@ -176,16 +178,182 @@ class KLinearTorch(KLinearBase):
        if self.has_bias:
            self.bias = None


class KLinearQ8(KLinearBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cuda",
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.compute_dtype = torch.float32
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None
        self.bias = None
        self.loaded = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        out_device = x.device
        x = x.to(device=self.device, dtype=self.compute_dtype)
        # Use the original weights for the matmul to mimic the original behavior:
        # dequantize the weights, then do the matrix multiplication
        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
        out = x @ weight_dequant.T
        if self.has_bias:
            out = out + self.bias
        return out.to(dtype=orig_dtype, device=out_device)

    def _dequantize_weight(self, q_matrix, scales, bits=8):
        """
        Dequantize a low-precision matrix back to floating-point

        Args:
            q_matrix (torch.Tensor): Quantized int matrix
            scales (torch.Tensor): Scale factors for each column
            bits (int): Quantization bits used (8 or 4)

        Returns:
            torch.Tensor: Dequantized floating-point matrix
        """
        # Ensure inputs are torch tensors
        if not isinstance(q_matrix, torch.Tensor):
            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
        if not isinstance(scales, torch.Tensor):
            scales = torch.tensor(scales, dtype=torch.float32)

        # Convert to correct dtype if needed
        if q_matrix.dtype != torch.int8:
            q_matrix = q_matrix.to(torch.int8)
        if scales.dtype != torch.float32:
            scales = scales.to(torch.float32)

        # For Q4, ensure the values stay within 4-bit range
        if bits == 4:
            q_matrix = torch.clamp(q_matrix, -7, 7)

        rows, cols = q_matrix.shape
        dequant_matrix = q_matrix.to(torch.float32)
        scales_broadcast = scales.view(1, cols)

        # Apply dequantization to all columns at once using matrix multiplication
        dequant_matrix = dequant_matrix * scales_broadcast

        return dequant_matrix

    def _quantize_weight(self, matrix, bits=8):
        """
        Quantize a floating-point matrix to lower precision (Q8 or Q4)

        Args:
            matrix (torch.Tensor): Input matrix in floating-point format
            bits (int): Quantization bits, either 8 or 4

        Returns:
            tuple: (quantized int matrix, scale factors for each column)
        """
        if not isinstance(matrix, torch.Tensor):
            matrix = torch.tensor(matrix, dtype=torch.float32)

        # Convert to float32 if needed
        if matrix.dtype != torch.float32:
            matrix = matrix.to(torch.float32)

        # Get matrix shape
        rows, cols = matrix.shape

        # Determine quantization parameters based on bits
        if bits == 8:
            max_int = 127
            qtype = torch.int8
        elif bits == 4:
            max_int = 7
            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range, wait for native support
        else:
            raise ValueError("Quantization bits must be either 8 or 4")

        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)

        # Calculate max absolute value for each column
        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)

        # Handle zero columns (avoid division by zero)
        zero_cols = max_abs_vals == 0
        max_abs_vals[zero_cols] = 1.0

        # Calculate scale factors for all columns at once
        scales = max_abs_vals / max_int

        # Prepare the scales for broadcasting [1, cols]
        scales_broadcast = scales.view(1, cols)

        # Apply quantization to the entire matrix at once
        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)

        # For Q4, clamp values to ensure they stay within 4-bit range
        if bits == 4:
            q_matrix = torch.clamp(q_matrix, -max_int, max_int)

        return q_matrix, scales

    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
        if self.loaded: return
        if device is None: device = self.device
        if w is None:
            w = self.load_weight(device=device)

        if isinstance(w, nn.Parameter):
            try:
                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
            except:
                weight = w.to(dtype=self.compute_dtype)
            self.has_bias = False
        elif isinstance(w, tuple):
            try:
                weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
            except:
                weight = w[0].to(dtype=self.compute_dtype)
            self.bias = w[1].to(dtype=self.compute_dtype).to(device)
            self.has_bias = True
        else:
            raise ValueError("Invalid weight type")

        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
        self.weight = self.weight.to(device)
        self.weight_scale = self.weight_scale.to(device)
        if self.has_bias:
            self.bias = self.bias.to(device)
        self.loaded = True

    def unload(self):
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None
        self._orig_weight = None
        if self.has_bias:
            self.bias = None
        self.loaded = False
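The two helpers above implement plain symmetric per-column int8 quantization (scale = max|w| / 127, no zero point). A standalone sketch of the same round trip, to illustrate the expected reconstruction error; the tensor here is random and not taken from the commit.

# Symmetric per-column int8 quantization round trip, mirroring
# KLinearQ8._quantize_weight / _dequantize_weight (standalone sketch).
import torch

w = torch.randn(128, 64)                        # [rows, cols] float32 weight
max_abs = w.abs().amax(dim=0).clamp(min=1e-12)  # per-column max magnitude
scales = max_abs / 127.0                        # one scale per column
q = torch.round(w / scales.view(1, -1)).to(torch.int8)
w_hat = q.to(torch.float32) * scales.view(1, -1)

rel_err = (w - w_hat).abs().max() / w.abs().max()
print(f"max relative error: {rel_err:.4f}")     # roughly on the order of 1/255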
class KLinearFP8(KLinearBase):
    # this kernel requires special handling for weight
    # Please load the weight file downloaded from KVCache.AI
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
    sort_indices: torch.Tensor
    has_bias: bool
    weight: torch.Tensor
    scale_w: torch.Tensor
    bias: torch.Tensor
    def __init__(
        self,
 ...
@@ -360,7 +528,7 @@ class KLinearMarlin(KLinearBase):
         self.workspace = None

 class KLinearCPUInfer(KLinearBase):
-    CPU_INFER = CPUInfer(Config().cpu_infer)
+    CPU_INFER = None
     def __init__(
         self,
         key: str,
 ...
@@ -374,6 +542,8 @@ class KLinearCPUInfer(KLinearBase):
         **kwargs,
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        if KLinearCPUInfer.CPU_INFER is None:
+            KLinearCPUInfer.CPU_INFER = CPUInfer(Config().cpu_infer)
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
         self.w = None
 ...
@@ -466,6 +636,7 @@ LINEAR_MAP = {
     "KLinearTorch": KLinearTorch,
     "KLinearCPUInfer": KLinearCPUInfer,
     "KLinearFP8": KLinearFP8,
+    "KLinearQ8": KLinearQ8,
 }

 class KTransformersLinear(BaseInjectedModule, KLinearBase):
 ...
@@ -475,7 +646,6 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        # device: str = "cuda",
         generate_device: str = "cuda",
         generate_op: str | None = "KLinearMarlin",
         prefill_device: str = "cuda",
 ...
ktransformers/operators/models.py  (view file @ 52fa671c)
 ...
@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
     DeepseekV2DecoderLayer,
     DeepseekV2MoE,
 )
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
 ...
@@ -649,8 +650,8 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability() < 8:
-                print("for Windows or GPU before ampere, use forward_windows")
+            if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+                # print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
                     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
 ...
@@ -673,6 +674,7 @@ class KDeepseekV2Model(BaseInjectedModule):
         t_f = 0

         for i, decoder_layer in enumerate(self.layers):
+            # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
             if self.transfer_map is not None and i in self.transfer_map:
                 prev_stream = torch.cuda.current_stream()
                 cur_device = self.transfer_map[i]
 ...
ktransformers/operators/triton_attention.py  (view file @ 52fa671c)
 ...
@@ -6,7 +6,7 @@
 import triton
 import triton.language as tl
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
 ...
@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
     # [TODO] work around shmem limit on MI3xx
     # TODO: support hip
-    # if is_hip_ and Lk >= 576:
-    #     BLOCK = 16
+    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
+        BLOCK = 16

     if Lk == 576:
         BLOCK_DMODEL = 512
 ...
ktransformers/operators/triton_attention_prefill.py  (new file, 0 → 100644, view file @ 52fa671c)

# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
"""
Memory-efficient attention for prefill.
It supporst page size = 1.
"""

# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    end_n = (
        cur_batch_seq_len
        if not IS_CAUSAL
        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    )
    for start_n in range(0, block_mask * end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(
                (start_n + offs_n[None, :] < cur_batch_seq_len)
                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
                0,
                float("-inf"),
            )
        else:
            qk += tl.where(
                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
            )

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]))


def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=BLOCK,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
    )
\ No newline at end of file
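For orientation, a hypothetical smoke test of context_attention_fwd on a tiny packed (varlen) batch; it is not part of the commit and assumes a CUDA GPU with Triton installed. Shapes follow the docstring: q/k/v/o are [total_tokens, heads, head_dim], b_start_loc and b_seq_len describe the packed batch layout.

# Hypothetical smoke test for context_attention_fwd (not part of this commit).
import torch
from ktransformers.operators.triton_attention_prefill import context_attention_fwd

seq_lens = torch.tensor([5, 3], dtype=torch.int32, device="cuda")
start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")  # cumulative offsets
total, heads, head_dim = int(seq_lens.sum()), 8, 64

q = torch.randn(total, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(total, heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(total, heads, head_dim, device="cuda", dtype=torch.float16)
o = torch.empty_like(q)

context_attention_fwd(q, k, v, o, start_loc, seq_lens, int(seq_lens.max()))
print(o.shape)  # torch.Size([8, 8, 64])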
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml  (view file @ 52fa671c)
 ...
@@ -22,7 +22,7 @@
   replace:
     class: ktransformers.operators.linear.KTransformersLinear
     kwargs:
-      generate_device: "cuda"
+      generate_device: "cpu"
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
 ...
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml  (view file @ 52fa671c)
 ...
@@ -26,7 +26,7 @@
 - match:
     class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
-    class: ktransformers.operators.gate.KMoEGate
+    class: ktransformers.operators.gate.KMoEGateDeepSeekV3
   kwargs:
     generate_device: "cuda:0"
     prefill_device: "cuda:0"
 ...
This is page 1 of 3 of the changed-file listing for this commit.