Commit 91c7b53c authored by Samuel Tesfai

Move GEMM kernels to awq

Remove torch dependencies in the GEMM kernel
Merge tinychat_pybind and pybind
Merge the load and setup scripts
parent ebe520c5
# -*- coding: utf-8 -*-
"""TinyChat Extension."""
import os

from torch.utils.cpp_extension import load

__all__ = ["_C"]

dirpath = os.path.dirname(__file__)

_C = load(
    name="nunchaku_tinychat_C",
    sources=[
        f"{dirpath}/tinychat_pybind.cpp",
        f"{dirpath}/quantization/gemv/gemv_cuda.cu",
        f"{dirpath}/quantization/gemm/gemm_cuda.cu",
    ],
    extra_cflags=["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++20"],
    extra_cuda_cflags=[
        "-O3",
        "-std=c++20",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_HALF2_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        "--ptxas-options=--allow-expensive-optimizations=true",
        "--threads=8",
    ],
)
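A minimal usage sketch (not part of this commit) for the JIT-built module above. Only the binding name awq_gemm_forward_cuda and its four-tensor signature come from the tinychat_pybind.cpp and gemm_cuda.h hunks below; the import path and the helper function are illustrative assumptions, and the packed weight, scales, and zeros must come from an AWQ weight packer.

import torch

from nunchaku.tinychat import _C  # assumed import path; the extension is compiled on first import


def awq_linear(x: torch.Tensor, packed_kernel: torch.Tensor,
               scales: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
    # Thin wrapper over the AWQ W4A16 GEMM binding. x is an fp16/bf16 CUDA tensor;
    # packed_kernel, scales, and zeros are AWQ-packed operands whose layout is defined
    # by the kernels. The output's last dimension equals packed_kernel.size(0) * 4 (kInterleave).
    return _C.awq_gemm_forward_cuda(x.contiguous(), packed_kernel, scales, zeros)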
@@ -3,6 +3,7 @@
#include "interop/torch.h"
#include "kernels/zgemm/zgemm.h"
#include "kernels/awq/gemv_awq.h"
#include "kernels/awq/gemm_cuda.h"
namespace nunchaku::ops {
@@ -94,4 +95,26 @@ namespace nunchaku::ops {
    return output;
}

torch::Tensor gemm_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros)
{
    Tensor result = ::awq_gemm_forward_cuda(
        from_torch(_in_feats.contiguous()),
        from_torch(_kernel.contiguous()),
        from_torch(_scaling_factors.contiguous()),
        from_torch(_zeros.contiguous())
    );
    torch::Tensor output = to_torch(result);
    Tensor::synchronizeDevice();
    return output;
}

};
\ No newline at end of file
@@ -4,10 +4,16 @@
#include "sana.h"
#include "ops.h"
#include "utils.h"
#include <torch/extension.h>
#include "awq/gemm_cuda.h"
#include "awq/gemv_awq.h"
#include <pybind11/pybind11.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("awq_gemm_forward_cuda", &awq_gemm_forward_cuda, "AWQ quantized GEMM kernel.");
    m.def("gemv_awq", &gemv_awq, "AWQ quantized GEMV kernel.");
    py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
        .def(py::init<>())
        .def("init", &QuantizedFluxModel::init,
......
#include <torch/extension.h>

torch::Tensor awq_gemm_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scales,
    torch::Tensor _zeros);
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "quantization/gemm/gemm_cuda.h"
#include "quantization/gemv/gemv_cuda.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("awq_gemm_forward_cuda", &awq_gemm_forward_cuda, "AWQ quantized GEMM kernel.");
    m.def("awq_gemv_forward_cuda", &awq_gemv_forward_cuda, "AWQ quantized GEMV kernel.");
}
\ No newline at end of file
import os
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

class CustomBuildExtension(BuildExtension):
    def build_extensions(self):
        for ext in self.extensions:
@@ -18,7 +16,6 @@ class CustomBuildExtension(BuildExtension):
            ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
        super().build_extensions()

if __name__ == "__main__":
    fp = open("nunchaku/__version__.py", "r").read()
    version = eval(fp.strip().split()[-1])
@@ -32,9 +29,11 @@ if __name__ == "__main__":
        "third_party/mio/include",
        "third_party/spdlog/include",
        "third_party/Block-Sparse-Attention/csrc/block_sparse_attn",
        "src/interop",
        "src/kernels",
    ]
    INCLUDE_DIRS = [ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]
    INCLUDE_DIRS = [os.path.join(ROOT_DIR, dir) for dir in INCLUDE_DIRS]

    DEBUG = False
@@ -117,11 +116,16 @@ if __name__ == "__main__":
            "src/kernels/dwconv.cu",
            "src/kernels/gemm_batched.cu",
            "src/kernels/gemm_f16.cu",
            "src/kernels/awq/gemm_cuda.cu",
            "src/kernels/awq/gemv_awq.cu",
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
        ],
        extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
        extra_compile_args={
            "gcc": GCC_FLAGS,
            "msvc": MSVC_FLAGS,
            "nvcc": NVCC_FLAGS,
        },
        include_dirs=INCLUDE_DIRS,
    )
......
@@ -2,9 +2,14 @@
#include <cuda_bf16.h>
#include "semaphore.h"
#include "gemm_cuda.h"
#include "../dequantize.cuh"
#include "../../utils.cuh"
#include <torch/extension.h>
//#include "../../../nunchaku/csrc/quantization/dequantize.cuh"
#include "dequantize.cuh"
#include <stdio.h>
#include "../dispatch_utils.h"
//#include "../../../nunchaku/csrc/utils.cuh"
#include "../utils.cuh"
#include <cuda_pipeline_primitives.h>
#define kInterleave 4
@@ -25,8 +30,8 @@
#endif
#define KERNEL_LAUNCH_CODE \
    int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
    torch::Tensor _semaphores = torch::empty({num_mn_tiles}, options_int); \
    int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
    Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device()); \
    auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>()); \
    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
    constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
@@ -301,7 +306,7 @@ __device__ __inline__ void share_to_reg_one_stage_B(f16_t *src, f16_t *src_scale
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_f16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
#pragma unroll
for (int i = 0; i < 4; i++)
{
@@ -758,7 +763,7 @@ __device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src, f16_t *src_sc
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
dequantize_s4_to_f16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
dequantize_s4_to_fp16x2(*reinterpret_cast<f162_t *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
#pragma unroll
for (int i = 0; i < 4; i++)
{
@@ -949,25 +954,25 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_
    }
}

torch::Tensor awq_gemm_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scales,
    torch::Tensor _zeros)
Tensor awq_gemm_forward_cuda(
    Tensor _in_feats,
    Tensor _kernel,
    Tensor _scales,
    Tensor _zeros)
{
    std::vector<int64_t> output_shape = _in_feats.sizes().vec();
    auto output_shape = _in_feats.shape.dataExtent;
    output_shape.back() = _kernel.size(0) * kInterleave;
    int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
    int num_in_channels = _in_feats.size(-1);
    auto options =
        torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    auto options_int =
        torch::TensorOptions().dtype(torch::kInt32).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty(output_shape, options);
    auto options =
        Tensor::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    auto options_int =
        Tensor::TensorOptions().dtype(Tensor::INT32).device(_in_feats.device());
    Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
    int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
    int num_out_channels = _out_feats.size(-1);
    if (_in_feats.scalar_type() == at::ScalarType::Half)
    if (_in_feats.scalar_type() == Tensor::FP16)
    {
        using f16_t = half;
@@ -1057,7 +1062,7 @@ torch::Tensor awq_gemm_forward_cuda(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
}
    else if (_in_feats.scalar_type() == at::ScalarType::BFloat16)
    else if (_in_feats.scalar_type() == Tensor::BF16)
    {
        using f16_t = __nv_bfloat16;
@@ -1149,7 +1154,7 @@ torch::Tensor awq_gemm_forward_cuda(
    }
    else
    {
        AT_ERROR("Unsupported input type");
        throw std::runtime_error("Unsupported input type");
    }
    return _out_feats;
......
#pragma once

#include "common.h"
#include "Tensor.h"

Tensor awq_gemm_forward_cuda(
    Tensor _in_feats,
    Tensor _kernel,
    Tensor _scales,
    Tensor _zeros);
@@ -349,6 +349,29 @@ __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
#endif // ENABLE BF16
template <typename f16_t>
__device__ __forceinline__
packed_as<f16_t, 2>::type
f162f162(f16_t x);

template <>
__device__ __forceinline__
packed_as<half, 2>::type
f162f162<half>(half x)
{
    return __half2half2(x);
}

#ifdef ENABLE_BF16
template <>
__device__ __forceinline__
packed_as<__nv_bfloat16, 2>::type
f162f162<__nv_bfloat16>(__nv_bfloat16 x)
{
    return __bfloat162bfloat162(x);
}
#endif // ENABLE_BF16
template <typename To, typename Ti>
__device__ inline To cuda_sum(Ti val)
{
......