Commit ed843741 authored by Azure-Tang's avatar Azure-Tang
Browse files

merge main; Add torch q8 linear

parent 6c4ed591
......@@ -32,6 +32,7 @@ endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
......@@ -201,6 +202,31 @@ endif()
# message(STATUS "Can't found CUDA lib")
# endif()
if (NOT EXISTS $ENV{ROCM_PATH})
if (NOT EXISTS /opt/rocm)
set(ROCM_PATH /usr)
else()
set(ROCM_PATH /opt/rocm)
endif()
else()
set(ROCM_PATH $ENV{ROCM_PATH})
endif()
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa)
else()
set(MUSA_PATH /opt/musa)
endif()
else()
set(MUSA_PATH $ENV{MUSA_PATH})
endif()
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
......@@ -218,6 +244,14 @@ elseif (UNIX)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()
if (KTRANSFORMERS_USE_ROCM)
find_package(HIP REQUIRED)
if(HIP_FOUND)
include_directories("${HIP_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
endif()
endif()
if (KTRANSFORMERS_USE_MUSA)
if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa)
......@@ -258,6 +292,11 @@ elseif(UNIX)
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
if (KTRANSFORMERS_USE_ROCM)
add_compile_definitions(USE_HIP=1)
target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
message(STATUS "Building for HIP")
endif()
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif()
......
......@@ -7,28 +7,32 @@
* @LastEditTime : 2024-08-07 09:47:43
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_CPUINFER_H
#define CPUINFER_CPUINFER_H
#ifndef CPUINFER_CPUINFER_H
#define CPUINFER_CPUINFER_H
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
#include "vendors/musa.h"
#endif
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
#include "vendors/musa.h"
#elif KTRANSFORMERS_USE_ROCM
#define __HIP_PLATFORM_AMD__
#include "vendors/hip.h"
#endif
#include "backend.h"
#include "task_queue.h"
#include "backend.h"
#include "task_queue.h"
#include "../vendors/vendor.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-impl.h"
class CPUInfer {
class CPUInfer {
public:
CPUInfer(int thread_num) {
backend_ = new Backend(thread_num - 1);
......@@ -80,6 +84,6 @@ class CPUInfer {
public:
Backend* backend_;
TaskQueue* task_queue_;
};
};
#endif
\ No newline at end of file
#endif
\ No newline at end of file
#pragma once
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaHostFn_t hipHostFn_t
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef hip_bfloat16 nv_bfloat16;
#pragma once
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaHostFn_t musaHostFn_t
#define nv_bfloat16 mt_bfloat16
\ No newline at end of file
#define cudaSuccess musaSuccess
// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H
#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif
#endif // CPUINFER_VENDOR_VENDOR_H
\ No newline at end of file
......@@ -15,6 +15,7 @@
#include <torch/torch.h>
#include <cstdint>
#include <c10/cuda/CUDAGuard.h>
typedef hip_bfloat16 nv_bfloat16;
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
......
......@@ -36,7 +36,7 @@ inline std::string str(T x) {
namespace gptq_marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
......
......@@ -39,7 +39,7 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)
// No support for async
#else
......
......@@ -8,6 +8,11 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#ifdef __HIP_PLATFORM_AMD__
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
#endif
namespace gptq_marlin {
template <typename scalar_t>
......
......@@ -9,7 +9,6 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
#include "device_launch_parameters.h"
#include "llamafile/flags.h"
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
......
#pragma once
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaHostFn_t hipHostFn_t
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef hip_bfloat16 nv_bfloat16;
#pragma once
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaSuccess musaSuccess
// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H
#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif
#endif // CPUINFER_VENDOR_VENDOR_H
\ No newline at end of file
......@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
......@@ -169,7 +170,7 @@ def local_chat(
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
......
......@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from flash_attn import flash_attn_func
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
try:
from flash_attn import flash_attn_func
except:
pass
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
......@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
attn_output = flash_attn_func(
query_states,
key_states,
value_states_padded,
softmax_scale=self.softmax_scale,
causal=True,
# for bsz = 1
attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
max_input_len = q_len
context_attention_fwd(
q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
o=attn_output,
b_start_loc=b_start_loc,
b_seq_len=b_seq_len,
max_input_len=max_input_len,
is_causal=True
)
if self.q_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]
attn_output = attn_output[:, :, : self.v_head_dim]
attn_output = attn_output.reshape(
bsz, q_len, self.num_heads * self.v_head_dim
......@@ -589,8 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if os.name == 'nt' or get_compute_capability()<8:
print("for Windows or GPU before ampere, use forward_windows")
if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
return self.forward_windows(
hidden_states,
attention_mask,
......
......@@ -17,7 +17,10 @@ import logging
logger = logging.getLogger("dynamic_attention")
sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
from flash_attn import flash_attn_func, flash_attn_with_kvcache
try:
from flash_attn import flash_attn_func, flash_attn_with_kvcache
except:
print("falsh attn not found")
import math
......
......@@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config
from typing import Dict, Tuple, Optional, Union
import numpy as np
#class KLinearBase(BaseInjectedModule, ABC):
class KLinearBase(ABC):
......@@ -176,6 +178,195 @@ class KLinearTorch(KLinearBase):
if self.has_bias:
self.bias = None
class KLinearQ8(KLinearBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
group_size: int = 128, # 增大分组大小,减少量化噪声
percentile: float = 99.99, # 新增:对异常值进行截断的百分位数
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.compute_dtype = torch.float32
self.weight = None
self.weight_scale = None
self.weight_zero_point = None
self.bias = None
self.loaded = False
self.group_size = group_size
self.percentile = percentile
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
out_device = x.device
x = x.to(device=self.device, dtype=self.compute_dtype)
# 使用原始权重做矩阵乘法,模拟原始行为
# 反量化权重进行矩阵乘法
weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
out = x @ weight_dequant.T
if self.has_bias:
out = out + self.bias
return out.to(dtype=orig_dtype, device=out_device)
def _dequantize_weight(self, q_matrix, scales, bits=8):
"""
Dequantize a low-precision matrix back to floating-point
Args:
q_matrix (torch.Tensor): Quantized int matrix
scales (torch.Tensor): Scale factors for each column
bits (int): Quantization bits used (8 or 4)
Returns:
torch.Tensor: Dequantized floating-point matrix
"""
# Ensure inputs are torch tensors
if not isinstance(q_matrix, torch.Tensor):
q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
if not isinstance(scales, torch.Tensor):
scales = torch.tensor(scales, dtype=torch.float32)
# Convert to correct dtype if needed
if q_matrix.dtype != torch.int8:
q_matrix = q_matrix.to(torch.int8)
if scales.dtype != torch.float32:
scales = scales.to(torch.float32)
# For Q4, ensure the values stay within 4-bit range
if bits == 4:
q_matrix = torch.clamp(q_matrix, -7, 7)
# Get matrix shape
rows, cols = q_matrix.shape
# Convert to float32
dequant_matrix = q_matrix.to(torch.float32)
# Create broadcasted scales: reshape scales to [1, cols] for broadcasting
scales_broadcast = scales.view(1, cols)
# Apply dequantization to all columns at once using matrix multiplication
dequant_matrix = dequant_matrix * scales_broadcast
return dequant_matrix
def _quantize_weight(self, matrix, bits=8):
"""
Quantize a floating-point matrix to lower precision (Q8 or Q4)
Args:
matrix (torch.Tensor): Input matrix in floating-point format
bits (int): Quantization bits, either 8 or 4
Returns:
tuple: (quantized int matrix, scale factors for each column)
"""
if not isinstance(matrix, torch.Tensor):
matrix = torch.tensor(matrix, dtype=torch.float32)
# Convert to float32 if needed
if matrix.dtype != torch.float32:
matrix = matrix.to(torch.float32)
# Get matrix shape
rows, cols = matrix.shape
# Determine quantization parameters based on bits
if bits == 8:
# Q8: range is -127 to 127
max_int = 127
qtype = torch.int8
elif bits == 4:
# Q4: range is -7 to 7 (using 4-bit signed integers)
max_int = 7
qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range
else:
raise ValueError("Quantization bits must be either 8 or 4")
# Initialize results and scale factors
q_matrix = torch.zeros_like(matrix, dtype=qtype)
scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
# Initialize scale factors
scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
# Calculate max absolute value for each column
max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)
# Handle zero columns (avoid division by zero)
zero_cols = max_abs_vals == 0
max_abs_vals[zero_cols] = 1.0
# Calculate scale factors for all columns at once
scales = max_abs_vals / max_int
# Prepare the scales for broadcasting [1, cols]
scales_broadcast = scales.view(1, cols)
# Apply quantization to the entire matrix at once
q_matrix = torch.round(matrix / scales_broadcast).to(qtype)
# For Q4, clamp values to ensure they stay within 4-bit range
if bits == 4:
q_matrix = torch.clamp(q_matrix, -max_int, max_int)
return q_matrix, scales
def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
if self.loaded: return
if device is None: device = self.device
if w is None: w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
try:
weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
except:
weight = w.to(dtype=self.compute_dtype)
self.has_bias = False
elif isinstance(w, tuple):
try:
weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
except:
weight = w[0].to(dtype=self.compute_dtype)
self.bias = w[1].to(dtype=self.compute_dtype).to(device)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
self.weight = self.weight.to(device)
self.weight_scale = self.weight_scale.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
self.loaded = True
def unload(self):
self.weight = None
self.weight_scale = None
self.weight_zero_point = None
self._orig_weight = None
if self.has_bias:
self.bias = None
self.loaded = False
class KLinearFP8(KLinearBase):
# this kernel requires special handling for weight
# Please load the weight file downloaded from KVCache.AI
......@@ -468,6 +659,7 @@ LINEAR_MAP = {
"KLinearTorch": KLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer,
"KLinearFP8": KLinearFP8,
"KLinearQ8": KLinearQ8,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
......
......@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
DeepseekV2DecoderLayer,
DeepseekV2MoE,
)
from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
......@@ -649,8 +650,8 @@ class KDeepseekV2Model(BaseInjectedModule):
if per_layer_prefill_flag:
causal_mask = None
else:
if os.name == 'nt' or get_compute_capability()<8:
print("for Windows or GPU before ampere, use forward_windows")
if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
# print("for Windows or GPU before ampere, use forward_windows")
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
......@@ -673,6 +674,7 @@ class KDeepseekV2Model(BaseInjectedModule):
t_f = 0
for i, decoder_layer in enumerate(self.layers):
# print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
if self.transfer_map is not None and i in self.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.transfer_map[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment