Commit c009512a authored by Azure-Tang's avatar Azure-Tang
Browse files

Merge branch 'main' into hip

parents c1f13a69 4f22d726
......@@ -7,75 +7,83 @@
* @LastEditTime : 2024-08-07 09:47:43
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_CPUINFER_H
#define CPUINFER_CPUINFER_H
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#include "backend.h"
#include "task_queue.h"
#include "../vendors/vendor.h"
#include "llama.cpp/ggml-impl.h"
class CPUInfer {
public:
CPUInfer(int thread_num) {
backend_ = new Backend(thread_num - 1);
task_queue_ = new TaskQueue();
for (int i = 0; i < (1 << 16); ++i) {
ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
}
}
~CPUInfer() {
delete backend_;
delete task_queue_;
}
template <typename Func, typename Obj, typename... Args>
void enqueue(Func f, Obj* obj, Args... args) {
task_queue_->enqueue([=]() {
std::invoke(f, *obj, args..., backend_);
});
}
void submit(std::pair<intptr_t, intptr_t> params) {
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
func(args);
}
void sync() {
task_queue_->sync();
}
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
}
static void sync_(void* cpu_infer_ptr) {
CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
cpuinfer->sync();
}
void sync_with_cuda_stream(intptr_t user_cuda_stream) {
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
}
public:
Backend* backend_;
TaskQueue* task_queue_;
};
#endif
\ No newline at end of file
#ifndef CPUINFER_CPUINFER_H
#define CPUINFER_CPUINFER_H
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
#include "vendors/musa.h"
#elif KTRANSFORMERS_USE_ROCM
#define __HIP_PLATFORM_AMD__
#include "vendors/hip.h"
#endif
#include "backend.h"
#include "task_queue.h"
#include "../vendors/vendor.h"
#include "llama.cpp/ggml-impl.h"
class CPUInfer {
public:
CPUInfer(int thread_num) {
backend_ = new Backend(thread_num - 1);
task_queue_ = new TaskQueue();
for (int i = 0; i < (1 << 16); ++i) {
ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
}
}
~CPUInfer() {
delete backend_;
delete task_queue_;
}
template <typename Func, typename Obj, typename... Args>
void enqueue(Func f, Obj* obj, Args... args) {
task_queue_->enqueue([=]() {
std::invoke(f, *obj, args..., backend_);
});
}
void submit(std::pair<intptr_t, intptr_t> params) {
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
func(args);
}
void sync() {
task_queue_->sync();
}
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
void (*func)(void*) = (void (*)(void*))params.first;
void* args = (void*)params.second;
*((CPUInfer**)args) = this;
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
}
static void sync_(void* cpu_infer_ptr) {
CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
cpuinfer->sync();
}
void sync_with_cuda_stream(intptr_t user_cuda_stream) {
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
}
public:
Backend* backend_;
TaskQueue* task_queue_;
};
#endif
\ No newline at end of file
## TODO
This directory can be removed after updating the version of `llama.cpp`.
\ No newline at end of file
#pragma once
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cuDeviceGet hipDeviceGet
#define CUdevice hipDevice_t
#define CUdeviceptr hipDeviceptr_t
#define cuMemUnmap hipMemUnmap
#define CUmemAccessDesc hipMemAccessDesc
#define cuMemAddressFree hipMemAddressFree
#define cuMemRelease hipMemRelease
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
#define cuMemCreate hipMemCreate
#define cuMemAddressReserve hipMemAddressReserve
#define cuMemMap hipMemMap
#define cuMemSetAccess hipMemSetAccess
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
#define CUmemAllocationProp hipMemAllocationProp
#define cuDeviceGetAttribute hipDeviceGetAttribute
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaKernelNodeParams hipKernelNodeParams
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaStreamEndCapture hipStreamEndCapture
#define cudaGraphDestroy hipGraphDestroy
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
#define cudaGraphNodeGetType hipGraphNodeGetType
#define cudaGraphGetNodes hipGraphGetNodes
#define cudaGraphExecUpdate hipGraphExecUpdate
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define cudaStreamBeginCapture hipStreamBeginCapture
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaHostFn_t hipHostFn_t
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef hip_bfloat16 nv_bfloat16;
#pragma once
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaSuccess musaSuccess
// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
#ifndef CPUINFER_VENDOR_VENDOR_H
#define CPUINFER_VENDOR_VENDOR_H
#ifdef USE_CUDA
#include "cuda.h"
#elif USE_HIP
#define __HIP_PLATFORM_AMD__
#include "hip.h"
#elif USE_MUSA
#include "musa.h"
#endif
#endif // CPUINFER_VENDOR_VENDOR_H
\ No newline at end of file
/**
* @Description :
* @Author : Azure-Tang
* @Description :
* @Author : Azure-Tang, Boxin Zhang
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 03:05:04
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* @Version : 0.2.2
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "custom_gguf/ops.h"
#ifdef KTRANSFORMERS_USE_CUDA
#include "gptq_marlin/ops.h"
#endif
// Python bindings
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
......@@ -19,22 +19,53 @@
// namespace py = pybind11;
PYBIND11_MODULE(KTransformersOps, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
}, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
}, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
}, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
}, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
}, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
}, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
}, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
#ifdef KTRANSFORMERS_USE_CUDA
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
#endif
}
#include "ops.h"
// Python bindings
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/library.h>
#include <torch/extension.h>
#include <torch/torch.h>
// namespace py = pybind11;
int test(){
return 5;
}
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
PYBIND11_MODULE(cudaops, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("test", &test, "Function to test.");
}
......@@ -2,26 +2,55 @@
* @Description :
* @Author : Azure-Tang, Boxin Zhang
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 04:18:04
* @Version : 0.2.2
* Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
* Copyright (c) 2023-2024 The ggml authors
* Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*/
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/library.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <cstdint>
#include <c10/cuda/CUDAGuard.h>
__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
for(int i=0;i<blk_size;i++){
float scale = scales[block_id];
output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const int8_t* cur_block = data + block_id * blk_size;
float scale = __half2float(*((half*)cur_block));
cur_block += 2;
for (int i = 0; i < ele_per_blk; i++){
output_blk[i] = scale * cur_block[i];
}
}
}
__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
__half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const int8_t* cur_block = data + block_id * blk_size;
float scale = __half2float(*((half*)cur_block));
cur_block += 2;
for (int i = 0; i < ele_per_blk; i++) {
output_blk[i] = __float2half(scale * cur_block[i]);
}
}
}
__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const int8_t* cur_block = data + block_id * blk_size;
float scale = __half2float(*((half*)cur_block));
cur_block += 2;
for (int i = 0; i < ele_per_blk; i++) {
output_blk[i] = __float2bfloat16(scale * cur_block[i]);
}
}
}
......@@ -36,13 +65,13 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
}
}
__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
......@@ -70,17 +99,85 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size
}
}
__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
int is = 0;
float dl, ml;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
uint8_t sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);
scales = (uint8_t*)(data + block_id * blk_size + (is++));
sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);
shift += 2;
}
q += 32;
}
}
}
__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
int is = 0;
float dl, ml;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
uint8_t sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);
scales = (uint8_t*)(data + block_id * blk_size + (is++));
sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);
shift += 2;
}
q += 32;
}
}
}
__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux;
const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
......@@ -126,19 +223,131 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
}
}
__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux;
const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
uint8_t m = 1;
uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
for (int i = 0; i < 3; i++) {
aux[i] = 0;
for (int j = 0; j < 4; j++) {
aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
}
}
uint32_t tmp = aux[2];
aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
int is = 0;
float dl;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = __float2half(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)));
}
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));
}
shift += 2;
m <<= 1;
}
q += 32;
}
}
}
__global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux;
const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
uint8_t m = 1;
uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
for (int i = 0; i < 3; i++) {
aux[i] = 0;
for (int j = 0; j < 4; j++) {
aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
}
}
uint32_t tmp = aux[2];
aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
int is = 0;
float dl;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)));
}
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));
}
shift += 2;
m <<= 1;
}
q += 32;
}
}
}
__global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
// const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * 144 + 0)));
const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * 144 + 2)));
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < blk_size; j += 64) {
for (int j = 0; j < ele_per_blk; j += 64) {
uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
......@@ -151,13 +360,61 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
}
}
__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
// const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 2)));
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < ele_per_blk; j += 64) {
uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1);
for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l] >> 4) - m2);
q += 32; is += 2;
}
}
}
__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
// const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
int is = 0;
uint8_t sc, m;
for (int j = 0; j < ele_per_blk; j += 64) {
uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * (q[l] & 0xF) - m1);
for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * (q[l] >> 4) - m2);
q += 32; is += 2;
}
}
}
__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
......@@ -180,46 +437,165 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
}
}
__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));
__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
for (int j = 0; j < 256; j += 64) {
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);
for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2);
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
}
}
__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
for (int j = 0; j < 256; j += 64) {
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);
for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2);
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
}
}
__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);
const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
//if (blk_size == 256){
for (int n = 0; n < blk_size; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
output_blk[l + 0] = d * sc[is + 0] * q1;
output_blk[l + 32] = d * sc[is + 2] * q2;
output_blk[l + 64] = d * sc[is + 4] * q3;
output_blk[l + 96] = d * sc[is + 6] * q4;
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
for (int n = 0; n < ele_per_blk; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
output_blk[l + 0] = d * sc[is + 0] * q1;
output_blk[l + 32] = d * sc[is + 2] * q2;
output_blk[l + 64] = d * sc[is + 4] * q3;
output_blk[l + 96] = d * sc[is + 6] * q4;
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);
const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
for (int n = 0; n < ele_per_blk; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
output_blk[l + 0] = __float2half(d * sc[is + 0] * q1);
output_blk[l + 32] = __float2half(d * sc[is + 2] * q2);
output_blk[l + 64] = __float2half(d * sc[is + 4] * q3);
output_blk[l + 96] = __float2half(d * sc[is + 6] * q4);
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);
const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
for (int n = 0; n < ele_per_blk; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1);
output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2);
output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3);
output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4);
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
}
}
}
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
......@@ -236,152 +612,267 @@ __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_si
}
}
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
__half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
for (int ib = 0; ib < 8; ++ib) {
const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);
const float dl = d * (ls - 32);
for (int j = 0; j < 16; ++j) {
output_blk[j + 0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]);
output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >> 4]);
}
output_blk += 32;
qs += 16;
}
}
}
__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);
for (int ib = 0; ib < 8; ++ib) {
const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);
const float dl = d * (ls - 32);
for (int j = 0; j < 16; ++j) {
output_blk[j + 0] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] & 0xf]);
output_blk[j + 16] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] >> 4]);
}
output_blk += 32;
qs += 16;
}
}
}
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
// create gpu
auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto scales_gpu = torch::empty({{num_blocks, 1}}, options_scales);
auto qs_gpu = torch::empty({num_blocks, 32}, options_qs);
// read on cpu
options_scales = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCPU);
options_qs = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
// // reinterpret
auto scales = torch::from_blob(data.data_ptr(), {num_blocks, 1 + 16}, options_scales).slice(1, 0, 1);
auto qs = torch::from_blob(data.data_ptr(), {num_blocks, 2 + 32}, options_qs).slice(1, 2);
auto scales_f32 = scales.to(torch::kFloat32);
scales_gpu.copy_(scales_f32, false);
qs_gpu.copy_(qs, false);
// Create output tensor
auto output = torch::zeros_like(qs, torch::dtype(torch::kFloat32).device(device));
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({ num_bytes }, options);
// Launch kernel
dequantize_q8_0_kernel<<< 512, 256 >>>(
output.data_ptr<float>(), scales_gpu.data_ptr<float>(), qs_gpu.data_ptr<int8_t>(), num_blocks, 32);
cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
//data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device));
switch (target_dtype) {
case torch::kFloat16:
dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kBFloat16:
dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kFloat32:
dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break;
default:
printf("target type not support\n");
exit(0);
}
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device) {
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
// data.numel%blk_size should be 0, else raise err
int num_blocks = data.numel() / blk_size;
int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
auto data_gpu = torch::empty({num_bytes}, options);
data_gpu.copy_(data, false);
cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
//data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
// dequantize_q6_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), 256, num_blocks);
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
switch (target_dtype) {
case torch::kFloat16:
dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kBFloat16:
dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kFloat32:
dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break;
default:
printf("target type not support\n");
exit(0);
}
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
auto data_gpu = torch::empty({num_bytes}, options);
data_gpu.copy_(data, false);
cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
//data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
switch (target_dtype) {
case torch::kFloat16:
dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kBFloat16:
dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kFloat32:
dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break;
default:
printf("target type not support\n");
exit(0);
}
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
// data.numel%blk_size should be 0, else raise err
int num_blocks = data.numel() / blk_size;
int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
auto data_gpu = torch::empty({num_bytes}, options);
data_gpu.copy_(data, false);
cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
//data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q4_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), 256, num_blocks);
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
switch (target_dtype) {
case torch::kFloat16:
dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kBFloat16:
dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kFloat32:
dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break;
default:
printf("target type not support\n");
exit(0);
}
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
auto data_gpu = torch::empty({num_bytes}, options);
data_gpu.copy_(data, false);
cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
//data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
switch (target_dtype) {
case torch::kFloat16:
dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kBFloat16:
dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kFloat32:
dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break;
default:
printf("target type not support\n");
exit(0);
}
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
auto data_gpu = torch::empty({num_bytes}, options);
data_gpu.copy_(data, false);
cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
//data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
switch (target_dtype) {
case torch::kFloat16:
dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kBFloat16:
dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kFloat32:
dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break;
default:
printf("target type not support\n");
exit(0);
}
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
auto data_gpu = torch::empty({num_bytes}, options);
data_gpu.copy_(data, false);
cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);
//data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_iq4_xs_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
switch (target_dtype) {
case torch::kFloat16:
dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kBFloat16:
dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break;
case torch::kFloat32:
dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break;
default:
printf("target type not support\n");
exit(0);
}
cudaDeviceSynchronize();
return output;
}
/**
* @Description :
* @Description :
* @Author : Azure-Tang
* @Date : 2024-07-22 09:27:55
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 03:48:46
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#pragma once
......@@ -13,10 +13,10 @@
#include <torch/extension.h>
#include <torch/torch.h>
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
import os
import sys
sys.path.insert(0,"/home/zbx/ktransformers")
from ktransformers.util.custom_gguf import GGUFLoader
import torch
gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
gguf_loader_2 = GGUFLoader("/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/")
torch.set_default_dtype(torch.bfloat16)
tensor_1 = gguf_loader_1.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda")
tensor_2 = gguf_loader_2.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda")
print(tensor_1[0, -64:])
print(tensor_2[0, -64:])
\ No newline at end of file
......@@ -90,7 +90,7 @@ def marlin_quantize(
assert group_size <= size_k
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
act_order)
# For act_order, sort the "weights" and "g_idx" so that group ids are
......@@ -107,7 +107,7 @@ def marlin_quantize(
marlin_scale_perm_single[num_bits])
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)
......
......@@ -11,8 +11,7 @@ def get_pack_factor(num_bits):
return 32 // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
assert q_w.shape == w_ref.shape
def permute_rows(q_w: torch.Tensor, group_size: int):
orig_device = q_w.device
k_size, _ = q_w.shape
......@@ -26,10 +25,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous()
w_ref = w_ref[rand_perm, :].contiguous()
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
g_idx.to(device=orig_device),
rand_perm.to(device=orig_device),
......@@ -69,9 +66,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s
# Restore original shapes
if group_size < size_k:
......@@ -82,7 +76,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
return w
q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous()
......@@ -95,10 +88,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k)
w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
q_w, g_idx, rand_perm = permute_rows(q_w, group_size)
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
g_idx.to(device=orig_device),
......
......@@ -10,6 +10,8 @@
#include "kvcache.h"
#include <chrono>
void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,
float *attn_lse, int batch_size,
Backend *backend) {
......
......@@ -9,6 +9,9 @@
**/
#include "kvcache.h"
#include <chrono>
void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {
// Timer start
auto start = std::chrono::high_resolution_clock::now();
......
......@@ -10,6 +10,8 @@
#include "kvcache.h"
#include <chrono>
void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,
int block_idx, Backend *backend) {
// Timer start
......
......@@ -10,6 +10,8 @@
#include "kvcache.h"
#include <chrono>
std::string ggml_type_to_string(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
......
# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
from typing import Tuple
import torch
import triton
import triton.language as tl
from triton import Config
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
Args:
x_ptr (triton.Pointer): Pointer to the input tensor.
y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
Returns:
None
"""
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), 'Input tensor must be contiguous'
assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The scale tensor of shape (M, N).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
with torch.cuda.device(x.device):
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
Args:
a_ptr (tl.tensor): Pointer to the first input matrix A.
b_ptr (tl.tensor): Pointer to the second input matrix B.
c_ptr (tl.tensor): Pointer to the output matrix C.
a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
M (int): Number of rows in matrix A and C.
N (tl.constexpr): Number of columns in matrix B and C.
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
"""
Perform a matrix multiplication using FP8 precision.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
\ No newline at end of file
......@@ -28,8 +28,9 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.util.utils import prefill_and_generate, get_compute_capability
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
......@@ -53,7 +54,7 @@ default_optimize_rules = {
def local_chat(
model_path: str | None = None,
optimize_rule_path: str = None,
optimize_config_path: str = None,
gguf_path: str | None = None,
max_new_tokens: int = 300,
cpu_infer: int = Config().cpu_infer,
......@@ -61,9 +62,9 @@ def local_chat(
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
chunk_prefill_size: int = 8192
):
torch.set_grad_enabled(False)
Config().cpu_infer = cpu_infer
......@@ -94,12 +95,12 @@ def local_chat(
config, trust_remote_code=True, attn_implementation="flash_attention_2"
)
if optimize_rule_path is None:
if optimize_config_path is None:
if config.architectures[0] in default_optimize_rules:
print("using default_optimize_rule for", config.architectures[0])
optimize_rule_path = default_optimize_rules[config.architectures[0]]
optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
optimize_rule_path = input(
optimize_config_path = input(
"please input the path of your rule file(yaml file containing optimize rules):"
)
......@@ -107,18 +108,18 @@ def local_chat(
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):"
)
optimize_and_load_gguf(model, optimize_rule_path, gguf_path, config)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
try:
model.generation_config = GenerationConfig.from_pretrained(model_path)
except:
gen_config = GenerationConfig(
max_length=128,
temperature=0.7,
top_p=0.9,
do_sample=True
)
model.generation_config = gen_config
model.generation_config = GenerationConfig.from_pretrained(model_path)
except Exception as e:
print(f"generation config can't auto create, make default. Message: {e}")
gen_config = GenerationConfig(
temperature=0.6,
top_p=0.95,
do_sample=True
)
model.generation_config = gen_config
# model.generation_config = GenerationConfig.from_pretrained(model_path)
if model.generation_config.pad_token_id is None:
model.generation_config.pad_token_id = model.generation_config.eos_token_id
......@@ -167,13 +168,17 @@ def local_chat(
if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml"
torch.set_default_dtype(
torch.bfloat16
) # TODO: Remove this, replace dtype using config
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
)
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
)
else:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
)
if __name__ == "__main__":
fire.Fire(local_chat)
\ No newline at end of file
fire.Fire(local_chat)
......@@ -51,13 +51,34 @@ class StaticCache(transformers.StaticCache):
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
# key_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.qk_rope_head_dim + config.qk_nope_head_dim)
# value_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, config.v_head_dim)
key_shape = (max_batch_size, 1, self.max_cache_len, config.qk_rope_head_dim)
value_shape = (max_batch_size, 1, self.max_cache_len, config.kv_lora_rank)
self.page_size = 64
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
self.kv_lora_rank = config.kv_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
# TODO: support real page table
self.page_table_map = dict()
self.page_table_list = []
for idx in range(config.num_hidden_layers):
if isinstance(device, dict):
target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
else:
target_device = device
if target_device not in self.page_table_map:
page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
for seq_id in range(max_batch_size):
page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
self.page_table_map[target_device] = page_table
self.page_table_list.append(self.page_table_map[target_device])
self.is_MLA = True
self.is_page = True
else:
key_shape = cache_shape
value_shape = cache_shape
self.is_MLA = False
self.past_tokens = []
self.num_hidden_layers = config.num_hidden_layers
......@@ -68,10 +89,17 @@ class StaticCache(transformers.StaticCache):
target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
else:
target_device = device
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
if self.is_MLA:
new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
new_layer_value_cache = None
torch._dynamo.mark_static_address(new_layer_key_cache)
else:
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
self.past_tokens.append(0)
......@@ -104,11 +132,19 @@ class StaticCache(transformers.StaticCache):
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
#print(cache_position)
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.past_tokens[layer_idx] += cache_position.size(0)
return k_out, v_out
#print(cache_position)
if self.is_MLA:
page_idx = cache_position // self.page_size
page_offset = cache_position % self.page_size
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
return k_out, self.page_table_list[layer_idx]
else:
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
return k_out, v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
......@@ -134,8 +170,21 @@ class StaticCache(transformers.StaticCache):
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
if self.value_cache[layer_idx] is not None:
self.value_cache[layer_idx].zero_()
self.past_tokens[layer_idx] = 0
def remove_suffix(self, start_pos):
for layer_idx in range(len(self.key_cache)):
# In-place ops prevent breaking the static address
if self.is_MLA:
k_cache = self.key_cache[layer_idx]
k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
else:
self.key_cache[layer_idx][..., start_pos:, :].zero_()
self.value_cache[layer_idx][..., start_pos:, :].zero_()
self.past_tokens[layer_idx] = start_pos
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
"""Returns the maximum shape of the cache."""
return self.max_cache_len
\ No newline at end of file
return self.max_cache_len
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment