Commit 873a35be authored by muyangli

v0.1.4 ready to release


Co-authored-by: Zhekai Zhang <sxtyzhangzk@gmail.com>
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
Co-authored-by: Yujun Lin <16437040+synxlin@users.noreply.github.com>
parent d9cd6858
@@ -3,7 +3,7 @@ from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-canny-dev")
pipe = FluxControlPipeline.from_pretrained(
...
@@ -3,7 +3,7 @@ from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-depth-dev")
...
import torch
from diffusers import FluxPipeline
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
...
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-dev", offload=True
) # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", text_encoder_2=text_encoder_2, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable CPU offloading
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder_2=text_encoder_2,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
import torch
from diffusers import FluxPipeline
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
...
@@ -2,7 +2,7 @@ import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png")
mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
...
@@ -2,7 +2,7 @@ import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
...
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=True
) # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
text_encoder_2=text_encoder_2,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable CPU offloading
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
text_encoder_2=text_encoder_2,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
import torch
from diffusers import FluxPipeline
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
...
import torch
from diffusers import SanaPipeline
-from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
+from nunchaku import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
...
import torch
from diffusers import SanaPAGPipeline
-from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
+from nunchaku import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
pipe = SanaPAGPipeline.from_pretrained(
...
+from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel
-__version__ = "0.1.3"
+__version__ = "0.1.4"
@@ -9,9 +9,12 @@
class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
public:
-    void init(bool use_fp4, bool bf16, int8_t deviceId) {
+    void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedFluxModel");
-        net = std::make_unique<FluxModel>(use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        if (offload) {
+            spdlog::info("Layer offloading enabled");
+        }
+        net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
    }
    torch::Tensor forward(
...
@@ -3,7 +3,7 @@
#include "interop/torch.h"
#include "kernels/zgemm/zgemm.h"
#include "kernels/awq/gemv_awq.h"
-#include "kernels/awq/gemm_cuda.h"
+#include "kernels/awq/gemm_awq.h"
namespace nunchaku::ops {
@@ -72,7 +72,7 @@ namespace nunchaku::ops {
        alpha,
        getTensor(wcscales)
    );
-    Tensor::synchronizeDevice();
+    // Tensor::synchronizeDevice();
}
torch::Tensor gemv_awq(
@@ -97,12 +97,12 @@
    );
    torch::Tensor output = to_torch(result);
-    Tensor::synchronizeDevice();
+    // Tensor::synchronizeDevice();
    return output;
}
-torch::Tensor gemm_cuda(
+torch::Tensor gemm_awq(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
@@ -115,8 +115,9 @@
        from_torch(_zeros.contiguous())
    );
+    // TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
    torch::Tensor output = to_torch(result);
-    Tensor::synchronizeDevice();
+    // Tensor::synchronizeDevice();
    return output;
}
...
@@ -5,8 +5,6 @@
#include "ops.h"
#include "utils.h"
#include <torch/extension.h>
-#include "awq/gemm_cuda.h"
-#include "awq/gemv_awq.h"
#include <pybind11/pybind11.h>
@@ -15,6 +13,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        .def(py::init<>())
        .def("init", &QuantizedFluxModel::init,
            py::arg("use_fp4"),
+            py::arg("offload"),
            py::arg("bf16"),
            py::arg("deviceId")
        )
@@ -75,7 +74,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    ;
    m.def_submodule("ops")
-        .def("gemm_cuda", nunchaku::ops::gemm_cuda)
+        .def("gemm_awq", nunchaku::ops::gemm_awq)
        .def("gemv_awq", nunchaku::ops::gemv_awq)
    ;
...
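With the updated binding, the compiled extension's QuantizedFluxModel.init takes the new offload flag between use_fp4 and bf16. A minimal sketch of a call through this binding; the module handle nunchaku._C is illustrative only, since the actual name under which the compiled TORCH_EXTENSION_NAME module is exposed is not shown in this diff:

from nunchaku import _C  # hypothetical handle to the compiled extension

model = _C.QuantizedFluxModel()
# offload is the argument added in this commit; it enables layer offloading in FluxModel.
model.init(use_fp4=False, offload=True, bf16=True, deviceId=0)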
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
#include <cuda_fp16.h>
#include <cuda_bf16.h>
template <typename T>
__device__ __forceinline__ void dequantize_s4_to_f16x2(T const &source, uint4 *result);
template <>
__device__ __forceinline__ void dequantize_s4_to_f16x2<half2>(half2 const &source, uint4 *result)
{
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
template <>
__device__ __forceinline__ void dequantize_s4_to_f16x2<__nv_bfloat162>(__nv_bfloat162 const &source, uint4 *result)
{
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const source_i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate bf16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
uint32_t i4s = source_i4s;
// Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x43004300
i4s >>= 4;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x43004300
i4s >>= 4;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x43004300
i4s >>= 4;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// This is the BF16 {-136, -136} represented as an integer.
// static constexpr uint32_t BF16_BIAS = 0xC308C308;
// This is the BF16 {-128, -128} represented as an integer, we do not need to map to [-8, 7]
static constexpr uint32_t NEG_128 = 0xC300C300;
static constexpr uint32_t ONE = 0x3F803F80;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(ONE), "r"(NEG_128));
// Convert elt_23
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE), "r"(NEG_128));
// Convert elt_45
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(ONE), "r"(NEG_128));
// Convert elt_67
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE), "r"(NEG_128));
}
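As a numeric cross-check of the two specializations above: each packed 32-bit word holds eight unsigned 4-bit values, and the magic-number tricks decode them in the interleaved register order (0,4), (1,5), (2,6), (3,7) without re-centering to [-8, 7], leaving scale and zero-point application to the caller. A small NumPy reference sketch, written as an aid for reasoning rather than as part of the kernel:

import numpy as np

def dequantize_s4_reference(packed_word: int) -> np.ndarray:
    # Decode one uint32 the way dequantize_s4_to_f16x2 does numerically:
    # values stay in [0, 15]; the output order matches h[0..3].{x,y}.
    nibbles = [(packed_word >> (4 * k)) & 0xF for k in range(8)]
    order = [0, 4, 1, 5, 2, 6, 3, 7]  # interleaved register layout
    return np.array([float(nibbles[k]) for k in order], dtype=np.float32)

# dequantize_s4_reference(0x76543210) -> [0., 4., 1., 5., 2., 6., 3., 7.]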
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <stdio.h>
#include <torch/extension.h>
#include "gemv_cuda.h"
#include "../dequantize.cuh"
#include "../../utils.cuh"
#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128
// Reduce sum within the warp using the tree reduction algorithm.
template <typename fp_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(fp_t *psum, float (*out_smem)[Num * 4])
{
// kInterleave = 4
float fpsum[Num];
#pragma unroll
for (int i = 0; i < Num; ++i)
{
fpsum[i] = static_cast<float>(psum[i]);
}
#pragma unroll
for (int i = 0; i < Num; ++i)
{
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
}
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
{
#pragma unroll
for (int i = 0; i < Num; ++i)
{
out_smem[warp][i * 4 + lane / 2] = fpsum[i];
}
}
__syncthreads();
};
__device__ __forceinline__ int make_divisible(int c, int divisor)
{
return (c + divisor - 1) / divisor;
}
template <typename f16_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(
const f16_t *inputs, const uint32_t *weight, const f16_t *scales, const f16_t *zeros, f16_t *outputs,
const int IC, const int OC)
{
using f162_t = typename packed_as<f16_t, 2>::type;
using accum_t = float;
using accum2_t = typename packed_as<accum_t, 2>::type;
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kThreadsNumPerTile = kStride / kElemsPerThread;
static constexpr int kShuffleBasicTile = 2;
static constexpr int kShuffleContinous = 4;
static constexpr int kShuffleStrided = 4;
constexpr int Num = NPerBlock * Batch;
constexpr int kInterleave = 4;
alignas(16) f16_t local_inputs[kElemsPerThread];
alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
alignas(16) f16_t half_weight_buffer[kElemsPerThread];
alignas(16) f16_t dequantized_weight[kElemsPerThread * NPerBlock];
alignas(16) f16_t local_scale[NPerBlock];
alignas(16) f16_t local_scaled_zeros[NPerBlock];
accum_t psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = static_cast<accum_t>(0.f);
extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride + (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible
const uint32_t *blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const f16_t *scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const f16_t *zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const f16_t *inputs_ptr = inputs + act_k_offset;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC;
// Main loop iteration, each block completes the outputs for several OCs
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
{
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx)
{
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
*((float4 *)(local_qweights)) =
*((float4 *)(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
// Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_f16x2(*reinterpret_cast<f162_t *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
}
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i)
{
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j)
{
f162_t w =
*reinterpret_cast<f162_t *>(
half_weight_buffer + (i + j * kShuffleContinous) * kShuffleBasicTile);
w = __hfma2(w, f162f162(local_scale[idx]), f162f162(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
}
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
{
const f16_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
{
// load activation, 8 halves (128 bits) / step.
*((float4 *)(local_inputs + idx * 8)) = *((float4 *)(local_inputs_ptr + idx * 8));
}
// Perform the MACs
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x)
{
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y)
{
accum2_t prod = cuda_cast<accum2_t>(__hmul2(
*reinterpret_cast<f162_t *>(dequantized_weight + y * NPerBlock + x * 2),
f162f162(local_inputs[y])));
*reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2) = prod + *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2);
}
}
}
inputs_ptr += act_forward_step;
scale_ptr += scale_forward_step;
zeros_ptr += scale_forward_step;
}
warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);
// Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
{
int batch_idx = i / (NPerBlock * kInterleave);
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
{
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<f16_t>(acc);
}
}
/*
Computes GEMV (PyTorch interface).
Args:
_in_feats: tensor of shape [B, IC];
_kernel: int tensor of shape [OC, IC // 8];
_zeros: int tensor of shape [OC, IC // G // 8];
_scaling_factors: tensor of shape [OC, IC // G];
blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;
Returns:
out_feats: tensor of shape [B, OC];
*/
torch::Tensor awq_gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int m,
int n,
int k,
int group_size)
{
std::vector<int64_t> output_shape = _in_feats.sizes().vec();
output_shape.back() = n;
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty(output_shape, options);
static constexpr int N_PER_BLOCK = 2;
static constexpr int K_INTERLEAVE = 4;
static constexpr int BLOCK_SIZE = 256;
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);
AT_DISPATCH_REDUCED_FLOATING_TYPES(
_in_feats.scalar_type(),
"awq_gemv_forward_cuda",
[&]
{
using f16_t = typename to_cpp_t<scalar_t>::type;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<uint32_t *>(_kernel.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto scaling_factors = reinterpret_cast<f16_t *>(_scaling_factors.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (group_size == 128)
{
switch (m)
{
case 1:
gemv_kernel<f16_t, N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 2:
gemv_kernel<f16_t, N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 3:
gemv_kernel<f16_t, N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 4:
gemv_kernel<f16_t, N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 5:
gemv_kernel<f16_t, N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 6:
gemv_kernel<f16_t, N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 7:
gemv_kernel<f16_t, N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
default:
throw std::runtime_error("Unsupported batch size for gemv kernel.\n");
}
}
else
{
throw std::runtime_error("Unsupported group size for gemv kernel.\n");
}
});
return _out_feats;
}
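At the level of the math, the kernel above implements a group-quantized W4A16 GEMV: each weight is reconstructed as q * scale + scaled_zero within its group of group_size input channels, then multiplied against the activations. The NumPy sketch below states that contract in plain form; it deliberately ignores the interleaved uint32 packing and register shuffling handled by gemv_kernel, and the tensor layouts here are illustrative assumptions rather than the exact on-device format:

import numpy as np

def awq_gemv_reference(x, qweight, scales, scaled_zeros, group_size=128):
    # x:            [B, IC]  activations
    # qweight:      [OC, IC] unsigned 4-bit values stored as small integers
    # scales:       [OC, IC // group_size]
    # scaled_zeros: [OC, IC // group_size]  (already multiplied by -scale, added after scaling)
    OC, IC = qweight.shape
    group = np.arange(IC) // group_size
    w_dq = qweight * scales[:, group] + scaled_zeros[:, group]  # dequantized weights [OC, IC]
    return x @ w_dq.T                                           # [B, OC]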