OpenDAS / ktransformers / Commits

Commit c009512a authored Mar 13, 2025 by Azure-Tang

Merge branch 'main' into hip

Parents: c1f13a69, 4f22d726
Changes: 121
Showing 20 changed files with 1428 additions and 329 deletions (+1428 -329)
ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h                                   +80  -72
ktransformers/ktransformers_ext/cpu_backend/vendors/README.md                            +3   -0
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h                               +15  -0
ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h                                +172 -0
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h                               +137 -0
ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h                             +13  -0
ktransformers/ktransformers_ext/cuda/binding.cpp                                         +55  -24
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp                             +0   -35
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu                              +631 -140
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h                                   +9   -9
ktransformers/ktransformers_ext/cuda/test_dequant.py                                     +16  -0
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py   +2   -2
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py    +2   -10
ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp                       +2   -0
ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp                  +3   -0
ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp                 +2   -0
ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp                      +2   -0
ktransformers/ktransformers_ext/triton/fp8gemm.py                                        +193 -0
ktransformers/local_chat.py                                                              +28  -23
ktransformers/models/custom_cache.py                                                     +63  -14
ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h

@@ -7,24 +7,32 @@
  * @LastEditTime : 2024-08-07 09:47:43
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
  **/
 #ifndef CPUINFER_CPUINFER_H
 #define CPUINFER_CPUINFER_H
 #include <atomic>
 #include <condition_variable>
 #include <functional>
 #include <mutex>
 #include <queue>
 #include <thread>
 #include <vector>
+#ifdef KTRANSFORMERS_USE_CUDA
+#include "vendors/cuda.h"
+#elif KTRANSFORMERS_USE_MUSA
+#include "vendors/musa.h"
+#elif KTRANSFORMERS_USE_ROCM
+#define __HIP_PLATFORM_AMD__
+#include "vendors/hip.h"
+#endif
 #include "backend.h"
 #include "task_queue.h"
 #include "../vendors/vendor.h"
 #include "llama.cpp/ggml-impl.h"
 class CPUInfer {
    public:
     CPUInfer(int thread_num) {
         backend_ = new Backend(thread_num - 1);
@@ -76,6 +84,6 @@ class CPUInfer {
    public:
     Backend* backend_;
     TaskQueue* task_queue_;
 };
 #endif
\ No newline at end of file
ktransformers/ktransformers_ext/cpu_backend/vendors/README.md (new file, mode 100644)

## TODO
This directory can be removed after updating the version of `llama.cpp`.
\ No newline at end of file
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h (new file, mode 100644)
#pragma once
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
ktransformers/ktransformers_ext/cpu_backend/vendors/hip.h (new file, mode 100644)
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaHostFn_t hipHostFn_t
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef hip_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h (new file, mode 100644)
#pragma once
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaSuccess musaSuccess
// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
ktransformers/ktransformers_ext/cpu_backend/vendors/vendor.h (new file, mode 100644)
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H
#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif
#endif // CPUINFER_VENDOR_VENDOR_H
\ No newline at end of file
ktransformers/ktransformers_ext/cuda/binding.cpp

 /**
  * @Description :
- * @Author : Azure-Tang
+ * @Author : Azure-Tang, Boxin Zhang
  * @Date : 2024-07-25 13:38:30
- * @Version : 1.0.0
+ * @Version : 0.2.2
+ * @LastEditors : kkk1nak0
+ * @LastEditTime : 2024-08-12 03:05:04
  * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
  **/
 #include "custom_gguf/ops.h"
+#ifdef KTRANSFORMERS_USE_CUDA
 #include "gptq_marlin/ops.h"
+#endif
 // Python bindings
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
...
@@ -19,22 +19,53 @@
 // namespace py = pybind11;
 PYBIND11_MODULE(KTransformersOps, m) {
-    m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+    m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
+        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
+        return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
+    },
+    "Function to dequantize q8_0 data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
+        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
+        return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
+    },
+    "Function to dequantize q6_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
+        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
+        return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
+    },
+    "Function to dequantize q5_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
+        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
+        return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
+    },
+    "Function to dequantize q4_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
+        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
+        return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
+    },
+    "Function to dequantize q3_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
+        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
+        return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
+    },
+    "Function to dequantize q2_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+    m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
+        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
+        return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
+    },
+    "Function to dequantize iq4_xs data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+#ifdef KTRANSFORMERS_USE_CUDA
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
           py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
           py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
           py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
+#endif
 }
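The rebinding above changes the Python-facing signature of the dequantize functions: instead of a torch.Tensor plus blk_size and device, each binding now takes a raw buffer address, the byte count, the block size, the element count per block, and a target dtype. The sketch below is not part of the commit; it only illustrates how such a binding might be driven from Python, assuming a host-side numpy buffer of raw GGUF q8_0 blocks (34 bytes and 32 elements per block follow the GGUF q8_0 layout). The real call site in this repo is the GGUF loader in ktransformers.util.custom_gguf.

import numpy as np
import torch
import KTransformersOps  # the pybind11 module built from binding.cpp

# Hypothetical raw q8_0 data: n_blocks blocks of 34 bytes (fp16 scale + 32 int8 values each).
n_blocks = 4
raw = np.zeros(n_blocks * 34, dtype=np.uint8)

dequantized = KTransformersOps.dequantize_q8_0(
    raw.ctypes.data,          # data: address of the host buffer
    raw.size,                 # num_bytes
    34,                       # blk_size: bytes per q8_0 block
    32,                       # ele_per_blk: elements per q8_0 block
    torch.device("cuda:0"),   # device on which the output tensor is created
    torch.float16,            # target_dtype
)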
ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp (deleted, 100644 → 0)

#include "ops.h"
// Python bindings
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/library.h>
#include <torch/extension.h>
#include <torch/torch.h>
// namespace py = pybind11;

int test(){
    return 5;
}

torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);

PYBIND11_MODULE(cudaops, m) {
    m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
          py::arg("data"), py::arg("blk_size"), py::arg("device"));
    m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
          py::arg("data"), py::arg("blk_size"), py::arg("device"));
    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
          py::arg("data"), py::arg("blk_size"), py::arg("device"));
    m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
          py::arg("data"), py::arg("blk_size"), py::arg("device"));
    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
          py::arg("data"), py::arg("blk_size"), py::arg("device"));
    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
          py::arg("data"), py::arg("blk_size"), py::arg("device"));
    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
          py::arg("data"), py::arg("blk_size"), py::arg("device"));
    m.def("test", &test, "Function to test.");
}
ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu (+631 -140; diff collapsed on this page)
ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h

@@ -13,10 +13,10 @@
 #include <torch/extension.h>
 #include <torch/torch.h>
-torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
-torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
ktransformers/ktransformers_ext/cuda/test_dequant.py (new file, mode 100644)

import os
import sys
sys.path.insert(0, "/home/zbx/ktransformers")
from ktransformers.util.custom_gguf import GGUFLoader
import torch

gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
gguf_loader_2 = GGUFLoader("/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/")

torch.set_default_dtype(torch.bfloat16)

tensor_1 = gguf_loader_1.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda")
tensor_2 = gguf_loader_2.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda")

print(tensor_1[0, -64:])
print(tensor_2[0, -64:])
\ No newline at end of file
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py

@@ -90,7 +90,7 @@ def marlin_quantize(
     assert group_size <= size_k

     # Quantize (and apply act_order if provided)
-    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
-                                                       act_order)
+    q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
+                                                act_order)

     # For act_order, sort the "weights" and "g_idx" so that group ids are
@@ -107,7 +107,7 @@ def marlin_quantize(
                                  marlin_scale_perm_single[num_bits])

     # Create result
-    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
+    res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
     for i in range(len(res_list)):
         res_list[i] = res_list[i].to(w.device)
 ...
ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py

@@ -11,8 +11,7 @@ def get_pack_factor(num_bits):
     return 32 // num_bits

-def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
-    assert q_w.shape == w_ref.shape
+def permute_rows(q_w: torch.Tensor, group_size: int):
     orig_device = q_w.device
     k_size, _ = q_w.shape
@@ -26,10 +25,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
     g_idx = g_idx[rand_perm].contiguous()
     q_w = q_w[rand_perm, :].contiguous()
-    w_ref = w_ref[rand_perm, :].contiguous()

     return (
-        w_ref.to(device=orig_device),
         q_w.to(device=orig_device),
         g_idx.to(device=orig_device),
         rand_perm.to(device=orig_device),
@@ -69,9 +66,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
     q_w += half_q_val
     q_w = torch.clamp(q_w, 0, max_q_val)

-    # Compute ref (dequantized)
-    w_ref = (q_w - half_q_val).half() * s
-
     # Restore original shapes
     if group_size < size_k:
@@ -82,7 +76,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
         return w

     q_w = reshape_w(q_w)
-    w_ref = reshape_w(w_ref)

     s = s.reshape((-1, size_n)).contiguous()
@@ -95,10 +88,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
         ), "For act_order, groupsize = {} must be less than size_k = {}".format(
             group_size, size_k
         )
-        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
+        q_w, g_idx, rand_perm = permute_rows(q_w, group_size)

     return (
-        w_ref.to(device=orig_device),
         q_w.to(device=orig_device),
         s.to(device=orig_device),
         g_idx.to(device=orig_device),
 ...
ktransformers/ktransformers_ext/operators/kvcache/kvcache_attn.cpp

@@ -10,6 +10,8 @@
 #include "kvcache.h"

+#include <chrono>
+
 void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
                                 float *attn_lse, int batch_size,
                                 Backend *backend) {
 ...
ktransformers/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp

@@ -9,6 +9,9 @@
  **/
 #include "kvcache.h"

+#include <chrono>
+
 void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {
     // Timer start
     auto start = std::chrono::high_resolution_clock::now();
 ...
ktransformers/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp

@@ -10,6 +10,8 @@
 #include "kvcache.h"

+#include <chrono>
+
 void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
                                    int block_idx, Backend *backend) {
     // Timer start
 ...
ktransformers/ktransformers_ext/operators/kvcache/kvcache_utils.cpp

@@ -10,6 +10,8 @@
 #include "kvcache.h"

+#include <chrono>
+
 std::string ggml_type_to_string(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
 ...
ktransformers/ktransformers_ext/triton/fp8gemm.py (new file, mode 100644)

# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
from typing import Tuple

import torch
import triton
import triton.language as tl
from triton import Config


@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """
    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.

    Args:
        x_ptr (triton.Pointer): Pointer to the input tensor.
        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.

    Returns:
        None
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    s = tl.max(tl.abs(x)) / 448.
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)


def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (M, N).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    with torch.cuda.device(x.device):
        weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y


fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]

@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                    a_s_ptr, b_s_ptr,
                    M, N: tl.constexpr, K: tl.constexpr,
                    BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A.
        b_ptr (tl.tensor): Pointer to the second input matrix B.
        c_ptr (tl.tensor): Pointer to the output matrix C.
        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
        M (int): Number of rows in matrix A and C.
        N (tl.constexpr): Number of columns in matrix B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
    a_s_ptrs = a_s_ptr + offs_m * k
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(k):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1
    c = accumulator.to(c_ptr.dtype.element_ty)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
    """
    Perform a matrix multiplication using FP8 precision.

    Args:
        a (torch.Tensor): The first input matrix, must be contiguous.
        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
        b (torch.Tensor): The second input matrix, must be contiguous.
        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.

    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
    return c
\ No newline at end of file
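For orientation, the three entry points compose as follows: act_quant produces an FP8 tensor plus one float32 scale per 128-element block of the last dimension, weight_dequant undoes block-wise weight quantization, and fp8_gemm contracts an FP8 activation of shape (M, K) against an FP8 weight of shape (N, K) whose scales are laid out per 128x128 block. The sketch below is illustrative only and not part of the commit; the shapes, the import path, and the dummy weight scales are assumptions chosen to satisfy the kernel's scale layout.

import torch
from fp8gemm import act_quant, fp8_gemm  # assuming this file's directory is on sys.path

torch.set_default_dtype(torch.bfloat16)
M, K, N = 4, 7168, 1024                       # K and N chosen divisible by the 128-wide blocks

x = torch.randn(M, K, device="cuda")
x_q, x_s = act_quant(x)                       # x_q: (M, K) float8_e4m3fn, x_s: (M, K // 128) float32

# A stand-in FP8 weight with per-(128 x 128)-block scales of shape (N // 128, K // 128).
w_q = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
w_s = torch.ones(N // 128, K // 128, device="cuda", dtype=torch.float32)

y = fp8_gemm(x_q, x_s, w_q, w_s)              # result: (M, N) in the default dtype (bfloat16 here)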
ktransformers/local_chat.py

@@ -28,8 +28,9 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
-from ktransformers.util.utils import prefill_and_generate
+from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -53,7 +54,7 @@ default_optimize_rules = {
 def local_chat(
     model_path: str | None = None,
-    optimize_rule_path: str = None,
+    optimize_config_path: str = None,
     gguf_path: str | None = None,
     max_new_tokens: int = 300,
     cpu_infer: int = Config().cpu_infer,
@@ -61,9 +62,9 @@ def local_chat(
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
+    chunk_prefill_size: int = 8192
 ):
     torch.set_grad_enabled(False)
     Config().cpu_infer = cpu_infer
@@ -94,12 +95,12 @@ def local_chat(
         config, trust_remote_code=True, attn_implementation="flash_attention_2"
     )
-    if optimize_rule_path is None:
+    if optimize_config_path is None:
         if config.architectures[0] in default_optimize_rules:
             print("using default_optimize_rule for", config.architectures[0])
-            optimize_rule_path = default_optimize_rules[config.architectures[0]]
+            optimize_config_path = default_optimize_rules[config.architectures[0]]
         else:
-            optimize_rule_path = input(
+            optimize_config_path = input(
                 "please input the path of your rule file(yaml file containing optimize rules):"
             )
@@ -107,15 +108,15 @@ def local_chat(
         gguf_path = input(
             "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
         )
-    optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
+    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)

     try:
         model.generation_config = GenerationConfig.from_pretrained(model_path)
-    except:
+    except Exception as e:
+        print(f"generation config can't auto create, make default. Message: {e}")
         gen_config = GenerationConfig(
-            max_length=128,
-            temperature=0.6,
-            top_p=0.95,
+            temperature=0.7,
+            top_p=0.9,
             do_sample=True
         )
         model.generation_config = gen_config
@@ -167,11 +168,15 @@ def local_chat(
     if mode == 'long_context':
         assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
         "please change max_seq_len in ~/.ktransformers/config.yaml"
-    generated = prefill_and_generate(
-        model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
-    )
+    torch.set_default_dtype(torch.bfloat16) # TODO: Remove this, replace dtype using config
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
+            use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
+        )
+    else:
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
+        )
 ...
ktransformers/models/custom_cache.py

@@ -51,13 +51,34 @@ class StaticCache(transformers.StaticCache):
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
             # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
-            # key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
-            # value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
-            key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim)
-            value_shape = (max_batch_size, 1, self.max_cache_len, config.kv_lora_rank)
+            self.page_size = 64
+            self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
+            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+            self.kv_lora_rank = config.kv_lora_rank
+            self.qk_rope_head_dim = config.qk_rope_head_dim
+            # TODO: support real page table
+            self.page_table_map = dict()
+            self.page_table_list = []
+            for idx in range(config.num_hidden_layers):
+                if isinstance(device, dict):
+                    target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
+                else:
+                    target_device = device
+                if target_device not in self.page_table_map:
+                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
+                    for seq_id in range(max_batch_size):
+                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
+                    self.page_table_map[target_device] = page_table
+                self.page_table_list.append(self.page_table_map[target_device])
+            self.is_MLA = True
+            self.is_page = True
         else:
             key_shape = cache_shape
             value_shape = cache_shape
+            self.is_MLA = False

         self.past_tokens = []
         self.num_hidden_layers = config.num_hidden_layers
@@ -68,10 +89,17 @@ class StaticCache(transformers.StaticCache):
                 target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
             else:
                 target_device = device
+            if self.is_MLA:
+                new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
+                new_layer_value_cache = None
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+            else:
                 new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
                 new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
                 torch._dynamo.mark_static_address(new_layer_key_cache)
                 torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
             self.past_tokens.append(0)
@@ -104,10 +132,18 @@ class StaticCache(transformers.StaticCache):
         cache_position = cache_kwargs.get("cache_position")
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
+        self.past_tokens[layer_idx] += cache_position.size(0)
         #print(cache_position)
+        if self.is_MLA:
+            page_idx = cache_position // self.page_size
+            page_offset = cache_position % self.page_size
+            # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            return k_out, self.page_table_list[layer_idx]
+        else:
             k_out[:, :, cache_position] = key_states
             v_out[:, :, cache_position] = value_states
-        self.past_tokens[layer_idx] += cache_position.size(0)
         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
@@ -134,7 +170,20 @@ class StaticCache(transformers.StaticCache):
         for layer_idx in range(len(self.key_cache)):
             # In-place ops prevent breaking the static address
             self.key_cache[layer_idx].zero_()
-            self.value_cache[layer_idx].zero_()
+            if self.value_cache[layer_idx] is not None:
+                self.value_cache[layer_idx].zero_()
+            self.past_tokens[layer_idx] = 0
+
+    def remove_suffix(self, start_pos):
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            if self.is_MLA:
+                k_cache = self.key_cache[layer_idx]
+                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
+            else:
+                self.key_cache[layer_idx][..., start_pos:, :].zero_()
+                self.value_cache[layer_idx][..., start_pos:, :].zero_()
+            self.past_tokens[layer_idx] = start_pos

     def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
         """Returns the maximum shape of the cache."""
 ...
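The MLA branch above stores the KV cache as a paged latent buffer of shape (max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim): each token position is split into a page index and an offset within that page. A small illustrative check of that mapping (not part of the commit; page_size = 64 as in the diff):

import torch

page_size = 64
cache_position = torch.tensor([0, 1, 63, 64, 130])

page_idx = cache_position // page_size     # tensor([0, 0, 0, 1, 2])
page_offset = cache_position % page_size   # tensor([0, 1, 63, 0, 2])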