Commit 91c7b53c authored by Samuel Tesfai

Moving gemm kernels to awq

Removing torch dependencies in gemm kernel
Merged tinychat_pybind and pybind
Merged the load and setup script
parent ebe520c5
# -*- coding: utf-8 -*-
"""TinyChat Extension."""
import os
from torch.utils.cpp_extension import load
__all__ = ["_C"]
dirpath = os.path.dirname(__file__)
_C = load(
name="nunchaku_tinychat_C",
sources=[
f"{dirpath}/tinychat_pybind.cpp",
f"{dirpath}/quantization/gemv/gemv_cuda.cu",
f"{dirpath}/quantization/gemm/gemm_cuda.cu",
],
extra_cflags=["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++20"],
extra_cuda_cflags=[
"-O3",
"-std=c++20",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_HALF2_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=--allow-expensive-optimizations=true",
"--threads=8",
],
)
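A minimal usage sketch (not part of this commit): assuming the file above lives at a hypothetical package path such as nunchaku.tinychat, importing it compiles the listed sources on first use and exposes the AWQ bindings defined in tinychat_pybind.cpp further down in this diff.

from nunchaku import tinychat  # hypothetical import path for the module built above

# The JIT build runs on the first import; the bound kernels are then plain callables.
print(tinychat._C.awq_gemm_forward_cuda.__doc__)  # "AWQ quantized GEMM kernel."
print(tinychat._C.awq_gemv_forward_cuda.__doc__)  # "AWQ quantized GEMV kernel."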
@@ -3,6 +3,7 @@
 #include "interop/torch.h"
 #include "kernels/zgemm/zgemm.h"
 #include "kernels/awq/gemv_awq.h"
+#include "kernels/awq/gemm_cuda.h"
 namespace nunchaku::ops {
@@ -94,4 +95,26 @@ namespace nunchaku::ops {
         return output;
     }
+    torch::Tensor gemm_cuda(
+        torch::Tensor _in_feats,
+        torch::Tensor _kernel,
+        torch::Tensor _scaling_factors,
+        torch::Tensor _zeros)
+    {
+        Tensor result = ::awq_gemm_forward_cuda(
+            from_torch(_in_feats.contiguous()),
+            from_torch(_kernel.contiguous()),
+            from_torch(_scaling_factors.contiguous()),
+            from_torch(_zeros.contiguous())
+        );
+        torch::Tensor output = to_torch(result);
+        Tensor::synchronizeDevice();
+        return output;
+    }
 };
\ No newline at end of file
@@ -4,10 +4,16 @@
 #include "sana.h"
 #include "ops.h"
 #include "utils.h"
+#include <torch/extension.h>
+#include "awq/gemm_cuda.h"
+#include "awq/gemv_awq.h"
 #include <pybind11/pybind11.h>
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("awq_gemm_forward_cuda", &awq_gemm_forward_cuda, "AWQ quantized GEMM kernel.");
+    m.def("gemv_awq", &gemv_awq, "AWQ quantized GEMV kernel.");
     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
         .def(py::init<>())
         .def("init", &QuantizedFluxModel::init,
...
#include <torch/extension.h>

torch::Tensor awq_gemm_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scales,
    torch::Tensor _zeros);
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "quantization/gemm/gemm_cuda.h"
#include "quantization/gemv/gemv_cuda.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("awq_gemm_forward_cuda", &awq_gemm_forward_cuda, "AWQ quantized GEMM kernel.");
    m.def("awq_gemv_forward_cuda", &awq_gemv_forward_cuda, "AWQ quantized GEMV kernel.");
}
\ No newline at end of file
 import os
 import setuptools
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

 class CustomBuildExtension(BuildExtension):
     def build_extensions(self):
         for ext in self.extensions:
@@ -18,7 +16,6 @@ class CustomBuildExtension(BuildExtension):
             ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
         super().build_extensions()

 if __name__ == "__main__":
     fp = open("nunchaku/__version__.py", "r").read()
     version = eval(fp.strip().split()[-1])
@@ -32,9 +29,11 @@ if __name__ == "__main__":
         "third_party/mio/include",
         "third_party/spdlog/include",
         "third_party/Block-Sparse-Attention/csrc/block_sparse_attn",
+        "src/interop",
+        "src/kernels",
     ]
-    INCLUDE_DIRS = [ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]
+    INCLUDE_DIRS = [os.path.join(ROOT_DIR, dir) for dir in INCLUDE_DIRS]
     DEBUG = False
@@ -117,11 +116,16 @@ if __name__ == "__main__":
             "src/kernels/dwconv.cu",
             "src/kernels/gemm_batched.cu",
             "src/kernels/gemm_f16.cu",
+            "src/kernels/awq/gemm_cuda.cu",
             "src/kernels/awq/gemv_awq.cu",
             *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
             *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
         ],
-        extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
+        extra_compile_args={
+            "gcc": GCC_FLAGS,
+            "msvc": MSVC_FLAGS,
+            "nvcc": NVCC_FLAGS,
+        },
         include_dirs=INCLUDE_DIRS,
     )
...
@@ -2,9 +2,14 @@
 #include <cuda_bf16.h>
 #include "semaphore.h"
 #include "gemm_cuda.h"
-#include "../dequantize.cuh"
-#include "../../utils.cuh"
-#include <torch/extension.h>
-#include "../dispatch_utils.h"
+//#include "../../../nunchaku/csrc/quantization/dequantize.cuh"
+#include "dequantize.cuh"
+#include <stdio.h>
+//#include "../../../nunchaku/csrc/utils.cuh"
+#include "../utils.cuh"
 #include <cuda_pipeline_primitives.h>
 #define kInterleave 4
@@ -25,8 +30,8 @@
 #endif
 #define KERNEL_LAUNCH_CODE \
     int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
-    torch::Tensor _semaphores = torch::empty({num_mn_tiles}, options_int); \
+    Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device()); \
     auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>()); \
     constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
     constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
@@ -301,7 +306,7 @@ __device__ __inline__ void share_to_reg_one_stage_B(f16_t *src, f16_t *src_scale
     f162_t zero2 = f162f162(zero);
     f162_t loaded[4];
-    dequantize_s4_to_f16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
+    dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
 #pragma unroll
     for (int i = 0; i < 4; i++)
     {
@@ -758,7 +763,7 @@ __device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src, f16_t *src_sc
     f162_t scale2 = f162f162(scale);
     f162_t zero2 = f162f162(zero);
     f162_t loaded[4];
-    dequantize_s4_to_f16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
+    dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
 #pragma unroll
     for (int i = 0; i < 4; i++)
     {
@@ -949,25 +954,25 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_
     }
 }
-torch::Tensor awq_gemm_forward_cuda(
-    torch::Tensor _in_feats,
-    torch::Tensor _kernel,
-    torch::Tensor _scales,
-    torch::Tensor _zeros)
+Tensor awq_gemm_forward_cuda(
+    Tensor _in_feats,
+    Tensor _kernel,
+    Tensor _scales,
+    Tensor _zeros)
 {
-    std::vector<int64_t> output_shape = _in_feats.sizes().vec();
+    auto output_shape = _in_feats.shape.dataExtent;
     output_shape.back() = _kernel.size(0) * kInterleave;
     int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
     int num_in_channels = _in_feats.size(-1);
     auto options =
-        torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
+        Tensor::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
     auto options_int =
-        torch::TensorOptions().dtype(torch::kInt32).device(_in_feats.device());
+        Tensor::TensorOptions().dtype(Tensor::INT32).device(_in_feats.device());
-    at::Tensor _out_feats = torch::empty(output_shape, options);
+    Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
     int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
     int num_out_channels = _out_feats.size(-1);
-    if (_in_feats.scalar_type() == at::ScalarType::Half)
+    if (_in_feats.scalar_type() == Tensor::FP16)
     {
         using f16_t = half;
@@ -1057,7 +1062,7 @@ torch::Tensor awq_gemm_forward_cuda(
             in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
         }
     }
-    else if (_in_feats.scalar_type() == at::ScalarType::BFloat16)
+    else if (_in_feats.scalar_type() == Tensor::BF16)
    {
         using f16_t = __nv_bfloat16;
@@ -1149,7 +1154,7 @@ torch::Tensor awq_gemm_forward_cuda(
     }
     else
     {
-        AT_ERROR("Unsupported input type");
+        throw std::runtime_error("Unsupported input type");
     }
     return _out_feats;
...
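For reference, the output-shape logic of the rewritten awq_gemm_forward_cuda is unchanged: the output keeps the input's leading dimensions and its last dimension becomes _kernel.size(0) * kInterleave. A small Python sketch of that calculation (the helper name and example shapes are illustrative, not part of the commit):

K_INTERLEAVE = 4  # mirrors `#define kInterleave 4` in gemm_cuda.cu

def awq_gemm_output_shape(in_feats_shape, kernel_rows):
    # Leading dims are preserved; the last dim becomes kernel_rows * kInterleave
    # output channels, exactly as output_shape.back() is set above.
    return (*in_feats_shape[:-1], kernel_rows * K_INTERLEAVE)

# e.g. [batch, seq, K] activations against a packed kernel with N // 4 rows:
assert awq_gemm_output_shape((2, 128, 4096), 2752) == (2, 128, 11008)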
#pragma once

#include "common.h"
#include "Tensor.h"

Tensor awq_gemm_forward_cuda(
    Tensor _in_feats,
    Tensor _kernel,
    Tensor _scales,
    Tensor _zeros);
@@ -349,6 +349,29 @@ __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
 #endif // ENABLE BF16
+template <typename f16_t>
+__device__ __forceinline__
+packed_as<f16_t, 2>::type
+f162f162(f16_t x);
+
+template <>
+__device__ __forceinline__
+packed_as<half, 2>::type
+f162f162<half>(half x)
+{
+    return __half2half2(x);
+}
+
+#ifdef ENABLE_BF16
+template <>
+__device__ __forceinline__
+packed_as<__nv_bfloat16, 2>::type
+f162f162<__nv_bfloat16>(__nv_bfloat16 x)
+{
+    return __bfloat162bfloat162(x);
+}
+#endif
 template <typename To, typename Ti>
 __device__ inline To cuda_sum(Ti val)
 {
...