Commit bf4adfeb authored by sxtyzhangzk, committed by Zhekai Zhang

Support Turing (sm_75) architecture

parent 3ef186fd
......@@ -11,6 +11,9 @@ class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::
public:
void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedFluxModel on device {}", deviceId);
if (!bf16) {
spdlog::info("Use FP16 model");
}
if (offload) {
spdlog::info("Layer offloading enabled");
}
......@@ -20,6 +23,11 @@ public:
net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
bool isBF16() {
checkModel();
return net->dtype == Tensor::BF16;
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
......
......@@ -34,6 +34,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale)
.def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
.def("isBF16", &QuantizedFluxModel::isBF16)
;
py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
.def(py::init<>())
......@@ -98,5 +99,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
.def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
.def("trim_memory", nunchaku::utils::trim_memory)
.def("set_faster_i2f_mode", nunchaku::utils::set_faster_i2f_mode)
;
}
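A minimal sketch of how the new bindings can be exercised from Python. The import path is an assumption — this commit only shows the pybind11 definitions; the Python loader later in the diff refers to the utils module as `cutils`:

```python
# Hypothetical import path; only the pybind11 bindings are shown in this commit.
from nunchaku._C import QuantizedFluxModel, utils as cutils

model = QuantizedFluxModel()
model.init(False, False, False, 0)   # use_fp4=False, offload=False, bf16=False (FP16), device 0
model.load("/path/to/transformer_blocks.safetensors")  # hypothetical checkpoint path
assert not model.isBF16()            # new query added in this commit

cutils.set_faster_i2f_mode("enabled")  # new toggle; accepts "disabled" / "enabled" / "always", effective on sm_75 only
```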
......@@ -2,6 +2,7 @@
#include "common.h"
#include "Tensor.h"
#include "kernels/zgemm/zgemm.h"
namespace nunchaku::utils {
......@@ -30,4 +31,9 @@ namespace nunchaku::utils {
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
void set_faster_i2f_mode(std::string mode) {
spdlog::info("Set fasteri2f mode to {}", mode);
kernels::set_faster_i2f_mode(mode);
}
};
\ No newline at end of file
......@@ -30,7 +30,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
def __init__(self, m: QuantizedFluxModel, device: str | torch.device):
super(NunchakuFluxTransformerBlocks, self).__init__()
self.m = m
self.dtype = torch.bfloat16
self.dtype = torch.bfloat16 if m.isBF16() else torch.float16
self.device = device
@staticmethod
......@@ -188,13 +188,13 @@ class EmbedND(nn.Module):
def load_quantized_module(
path: str, device: str | torch.device = "cuda", use_fp4: bool = False, offload: bool = False
path: str, device: str | torch.device = "cuda", use_fp4: bool = False, offload: bool = False, bf16: bool = True
) -> QuantizedFluxModel:
device = torch.device(device)
assert device.type == "cuda"
m = QuantizedFluxModel()
cutils.disable_memory_auto_release()
m.init(use_fp4, offload, True, 0 if device.index is None else device.index)
m.init(use_fp4, offload, bf16, 0 if device.index is None else device.index)
m.load(path)
return m
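With the added `bf16` flag, the FP16 engine can now be requested explicitly, which is what pre-Ampere (sm_75) GPUs need since BF16 tensor-core support only arrived with sm_80. A minimal sketch (the path is hypothetical):

```python
m = load_quantized_module(
    "/path/to/quantized/transformer_blocks",  # hypothetical path
    device="cuda:0",
    use_fp4=False,
    offload=False,
    bf16=False,   # build the FP16 variant instead of the default BF16 one
)
assert not m.isBF16()
```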
......@@ -241,6 +241,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
if isinstance(device, str):
device = torch.device(device)
offload = kwargs.get("offload", False)
torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
transformer, unquantized_part_path, transformer_block_path = cls._build_model(
pretrained_model_name_or_path, **kwargs
......@@ -258,7 +259,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
elif "lora" in k:
new_quantized_part_sd[k] = v
transformer._quantized_part_sd = new_quantized_part_sd
m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4", offload=offload)
m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4", offload=offload, bf16=torch_dtype == torch.bfloat16)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
......
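At the user-facing level the engine dtype now follows `torch_dtype`, so a Turing user would load the transformer roughly as below (a sketch; the model id is illustrative, and `precision` defaults to "auto" per the code above):

```python
import torch

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "some-org/svdq-int4-flux.1-dev",   # illustrative model id
    torch_dtype=torch.float16,         # becomes bf16=False when load_quantized_module is called
)
```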
......@@ -47,12 +47,12 @@ def get_sm_targets() -> list[str]:
sm = f"{capability[0]}{capability[1]}"
if sm == "120" and support_sm120:
sm = "120a"
assert sm in ["80", "86", "89", "120a"], f"Unsupported SM {sm}"
assert sm in ["75", "80", "86", "89", "120a"], f"Unsupported SM {sm}"
if sm not in ret:
ret.append(sm)
else:
assert install_mode == "ALL"
ret = ["80", "86", "89"]
ret = ["75", "80", "86", "89"]
if support_sm120:
ret.append("120a")
return ret
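For a quick sanity check that a local card now falls inside the supported list, the compute capability can be queried the same way `get_sm_targets` does (a sketch assuming a CUDA-enabled torch install):

```python
import torch

capability = torch.cuda.get_device_capability(torch.cuda.current_device())
sm = f"{capability[0]}{capability[1]}"   # "75" on a Turing card such as an RTX 20-series GPU
print(sm, sm in ["75", "80", "86", "89", "120"])
```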
......@@ -142,6 +142,7 @@ if __name__ == "__main__":
*ncond("src/FluxModel.cpp"),
*ncond("src/SanaModel.cpp"),
"src/Serialization.cpp",
"src/Module.cpp",
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
......@@ -160,6 +161,7 @@ if __name__ == "__main__":
"src/kernels/zgemm/gemm_w4a4.cu",
"src/kernels/zgemm/gemm_w4a4_test.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
......
......@@ -699,7 +699,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
return { hidden_states, encoder_hidden_states };
}
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : offload(offload) {
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : dtype(dtype), offload(offload) {
for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
......
......@@ -143,6 +143,8 @@ public:
void setAttentionImpl(AttentionImpl impl);
public:
const Tensor::ScalarType dtype;
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
......
......@@ -52,13 +52,14 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->device);
dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
Module::loadParam(key, dst, src);
if (key == "lora_down") {
const int new_rank = dst.shape[0];
this->lora_rank = new_rank;
}
} else {
dst.copy_(src);
Module::loadParam(key, dst, src);
}
} else {
Module::loadParam(key, dst, src);
......@@ -143,16 +144,18 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->device);
dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
Module::loadParam(key, dst, src);
this->lora_rank = dst.shape[1];
this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
} else {
dst.copy_(src);
Module::loadParam(key, dst, src);
}
} else if (key == "wcscales") {
assert(src.ndims() == 1);
assert(src.shape[0] == out_features_pad);
dst = src.copy(this->device);
dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
Module::loadParam(key, dst, src);
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
......@@ -160,7 +163,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else if (src.dtype() == Tensor::FP16) {
*dst.data_ptr<float>() = float(*src.data_ptr<half>());
} else if (src.dtype() == Tensor::FP32) {
dst.copy_(src);
Module::loadParam(key, dst, src);
} else {
assert(false);
}
......
#include "common.h"
#include "Module.h"
#include "kernels/misc_kernels.h"
void Module::copyWithCast(Tensor dst, Tensor src) {
assert(dst.is_contiguous());
assert(dst.device().type == Device::CUDA);
if (src.device().type == Device::CUDA && src.device().idx == dst.device().idx) {
nunchaku::kernels::cast(src, dst);
} else {
// src is on another device (or the host): alias dst's buffer with src's dtype,
// copy the data over, then cast in place (FP16 and BF16 have the same element size)
Tensor tmp;
tmp.buffer = dst.buffer;
tmp.shape = dst.shape;
tmp.scalarType = src.scalarType;
tmp.copy_(src);
nunchaku::kernels::cast(tmp, dst);
}
}
......@@ -131,10 +131,23 @@ public:
m->enabledLazyLoad = val;
});
}
void setAutoCastFP16(bool val) {
traverse([val](Module *m) {
m->enabledAutoCastFP16 = val;
});
}
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) {
dst.copy_(src);
static const std::set<Tensor::ScalarType> whitelist = {
Tensor::FP16,
Tensor::BF16,
};
if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) && whitelist.contains(src.scalar_type())) {
copyWithCast(dst, src);
} else {
dst.copy_(src);
}
}
struct ChildrenRegisterHelper {
......@@ -187,6 +200,9 @@ protected:
}
}
private:
void copyWithCast(Tensor dst, Tensor src);
public:
Module *parent = nullptr;
std::string name = "";
......@@ -194,6 +210,7 @@ public:
std::map<std::string, Param> params;
bool enabledLazyLoad = false;
bool enabledAutoCastFP16 = true;
};
struct LayerOffloadHelper {
......
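The practical effect of `enabledAutoCastFP16` (presumably the reason for this change): when the parameter dtype and the checkpoint dtype differ but are both FP16/BF16, `loadParam` routes through `copyWithCast` instead of a plain `copy_`, so a BF16-published checkpoint can be loaded into the new FP16 engine. A sketch under that assumption (import and path as in the earlier sketch, both hypothetical):

```python
m = QuantizedFluxModel()
m.init(False, False, False, 0)                   # FP16 engine (bf16=False)
m.load("/path/to/bf16_checkpoint.safetensors")   # BF16 tensors are cast to FP16 in loadParam
```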
......@@ -319,6 +319,10 @@ __device__ __inline__ void share_to_reg_one_stage_B(f16_t *src, f16_t *src_scale
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G, int SPLITK>
__global__ void gemm_w4a16_T1(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales, f16_t *__restrict__ zeros, f16_t *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
#endif
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
......@@ -776,6 +780,10 @@ __device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src, f16_t *src_sc
template <typename f16_t, int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void gemm_w4a16_T2(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales, f16_t *__restrict__ zeros, f16_t *__restrict__ C, int M, int N, int K)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
#endif
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
......
......@@ -112,6 +112,13 @@ __global__ void gemv_kernel(
const half_t* inputs, const uint32_t* weight, const half_t* scales, const half_t* zeros, half_t* outputs,
const int IC, const int OC)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
if constexpr(std::is_same_v<half_t, __nv_bfloat16>) {
trap_unsupported_arch();
return;
}
#endif
using half2_t = typename packed_as<half_t, 2>::type;
using accum_t = float;
using accum2_t = typename packed_as<accum_t, 2>::type;
......@@ -273,46 +280,45 @@ Tensor gemv_awq(
int k,
int group_size)
{
using half_t = __nv_bfloat16;
// using half_t = half;
return dispatchFloat16(_scaling_factors.scalar_type(), [&]<typename half_t>() {
assert(isTypeMatch<half_t>(_in_feats.dtype()));
assert(isTypeMatch<half_t>(_in_feats.dtype()));
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = n;
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = n;
auto in_feats = reinterpret_cast<half_t*>(_in_feats.data_ptr<half_t>());
auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
auto zeros = reinterpret_cast<half_t*>(_zeros.data_ptr<half_t>());
auto scaling_factors = reinterpret_cast<half_t*>(_scaling_factors.data_ptr<half_t>());
auto in_feats = reinterpret_cast<half_t*>(_in_feats.data_ptr<half_t>());
auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
auto zeros = reinterpret_cast<half_t*>(_zeros.data_ptr<half_t>());
auto scaling_factors = reinterpret_cast<half_t*>(_scaling_factors.data_ptr<half_t>());
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
half_t * out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());
static constexpr int N_PER_BLOCK = 2;
static constexpr int K_INTERLEAVE = 4;
static constexpr int BLOCK_SIZE = 256;
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
half_t * out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());
static constexpr int N_PER_BLOCK = 2;
static constexpr int K_INTERLEAVE = 4;
static constexpr int BLOCK_SIZE = 256;
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);
constexpr int GROUP_SIZE = 64;
constexpr int GROUP_SIZE = 64;
assert(m > 0 && m < 8);
assert(group_size == GROUP_SIZE);
assert(m > 0 && m < 8);
assert(group_size == GROUP_SIZE);
dispatchVal(m, std::make_integer_sequence<int, 8>(), [&]<int M>() {
if constexpr (M == 0) {
assert(false);
return;
}
if constexpr (M > 0) {
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
checkCUDA(cudaGetLastError());
}
});
dispatchVal(m, std::make_integer_sequence<int, 8>(), [&]<int M>() {
if constexpr (M == 0) {
assert(false);
return;
}
if constexpr (M > 0) {
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
checkCUDA(cudaGetLastError());
}
return _out_feats;
});
return _out_feats;
}
......@@ -15,6 +15,20 @@ inline auto dispatchFloat(Tensor::ScalarType scalarType, F &&func) {
return func.template operator()<float>();
default:
assert(false);
throw std::invalid_argument("scalarType is not a floating type");
}
}
template<typename F>
inline auto dispatchFloat16(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__nv_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
default:
assert(false);
throw std::invalid_argument("scalarType is not a float16 type");
}
}
......
......@@ -41,13 +41,13 @@ Tensor gemm_f16(Tensor input, // FP16
using Gemm = cutlass::gemm::device::Gemm<
ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75>;
// cutlass::gemm::GemmShape<128, 128, 64>,
// cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
// cutlass::epilogue::thread::LinearCombination<
// ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
// ElementAccumulator, ElementComputeEpilogue>,
// cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
......
......@@ -247,6 +247,10 @@ void cast(Tensor input, Tensor output) {
assert(output.is_contiguous());
assert(input.shape.dataExtent == output.shape.dataExtent);
if (input.data_ptr() == output.data_ptr()) {
assert(input.scalar_size() == output.scalar_size());
}
auto stream = getCurrentCUDAStream();
dispatch(input.scalar_type(), [&]<typename input_t>() {
......
// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <cassert>
#include <cstdint>
#include <cfloat>
#include <type_traits>
#include <cstdio>
#include <cuda_fp16.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
__device__ __forceinline__
static void trap_unsupported_arch() {
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
printf("This kernel is not supported on your GPU\n");
}
__syncthreads();
__nanosleep(1000000);
__trap();
}
#if defined(ENABLE_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__device__ __forceinline__
static __nv_bfloat162 __hfma2(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c) {
trap_unsupported_arch();
return __nv_bfloat162(0.0f, 0.0f);
}
#endif
template<typename T> struct num_elems;
template <> struct num_elems<float> { static constexpr int value = 1; };
template <> struct num_elems<float2> { static constexpr int value = 2; };
......@@ -409,6 +429,9 @@ __device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#else
assert(false);
return 0;
#endif
}
#endif
......
......@@ -60,6 +60,12 @@ public:
using typename AttentionConfig::epilogue_half_t;
using typename AttentionConfig::epilogue_half2_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static constexpr bool IS_SM80 = true;
#else
static constexpr bool IS_SM80 = false;
#endif
struct GEMMConfig {
static constexpr int BLOCK_M = AttentionConfig::BLOCK_M;
static constexpr int BLOCK_N = AttentionConfig::HEAD_DIM;
......@@ -182,33 +188,9 @@ public:
__device__ __forceinline__
static packed_fpsum_t mma_f16xf16_f16(packed_fpsum_t a, packed_fpsum_t b, packed_fpsum_t psum) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(psum.x), "=r"(psum.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(psum.x), "r"(psum.y)
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(psum.z), "=r"(psum.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.z), "r"(b.w),
"r"(psum.z), "r"(psum.w)
);
return psum;
uint2 out1 = mma_m16n8k16_f16f16f16f16(a, uint2(b.x, b.y), uint2(psum.x, psum.y));
uint2 out2 = mma_m16n8k16_f16f16f16f16(a, uint2(b.z, b.w), uint2(psum.z, psum.w));
return packed_fpsum_t{out1.x, out1.y, out2.x, out2.y};
}
// set nan values to -inf
......@@ -224,6 +206,12 @@ public:
return __hmax2(input, half2_t(neginf, neginf));
}
__device__ __forceinline__
static float fix_nan(float input) {
static constexpr float neginf = -std::numeric_limits<float>::infinity();
return fmaxf(input, neginf);
}
__device__ __forceinline__
static packed_fpsum_t fix_nan(packed_fpsum_t input) {
input.x = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.x)));
......@@ -233,6 +221,15 @@ public:
return input;
}
__device__ __forceinline__
static packed_f32psum_t fix_nan(packed_f32psum_t input) {
#pragma unroll
for (int i = 0; i < 8; i++) {
input.data[i] = fix_nan(input.data[i]);
}
return input;
}
__device__ __forceinline__
static qk_warp compute_qk(q_warp Q, k_warp K) {
qk_warp QK;
......@@ -259,8 +256,13 @@ public:
for (int d = 0; d < WARP_D_TILES; d++) {
psum = mma_f16xf16_f16(Q[m * WARP_D_TILES + d], K[k * WARP_D_TILES + d], psum);
}
psum = fix_nan(psum);
QK[m * WARP_K_TILES_QK + k] = packed_fp16_to_fp32(psum);
if constexpr (IS_SM80) {
psum = fix_nan(psum);
QK[m * WARP_K_TILES_QK + k] = packed_fp16_to_fp32(psum);
} else {
QK[m * WARP_K_TILES_QK + k] = fix_nan(packed_fp16_to_fp32(psum));
}
#endif
}
......@@ -586,7 +588,7 @@ public:
L.fill(make_float2(0.0f, 0.0f));
M.fill(make_float2(neginf, neginf));
static constexpr int SHMEM_TILES = 4;
static constexpr int SHMEM_TILES = IS_SM80 ? 4 : 7;
static_assert(SHMEM_TILES <= Q.size());
using q_shmem_t = packed_q_t[NUM_WARPS][SHMEM_TILES][WARP_SIZE];
__shared__ q_shmem_t Q_shmem;
......@@ -610,6 +612,12 @@ public:
Q[Q.size() - 1 - i] = load<true>(&Q_shmem[warpId][i][laneId]);
}
if constexpr (!IS_SM80) {
if (k1 % 2 == 1) {
__syncthreads();
}
}
if (alwaysfalse) {
dummy = clock();
}
......@@ -638,6 +646,8 @@ public:
load_k(ptr_k, k1+1, K, k1+1 < ntokens_kv / WARP_K);
// if (alwaysfalse) {
// dummy = clock();
// }
......@@ -666,6 +676,7 @@ public:
template<typename Epilogue>
struct attention_fp16_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr int SHMEM_SIZE = 0; // sizeof(q_shmem_t);
__device__
......
......@@ -26,7 +26,7 @@
namespace nunchaku::kernels {
template<bool bf16>
template<bool bf16, bool faster_i2f = false>
class GEMMConfig_W4A4 {
public:
// BE CAREFUL: weights need to be repacked when the tiling size changes
......@@ -40,13 +40,17 @@ public:
static constexpr int INSN_N = 16;
static constexpr int INSN_K = 64;
// faster i2f conversion on sm_75
// may generate incorrect results in certain circumstances
static constexpr bool FASTER_I2F = faster_i2f;
using half_t = typename std::conditional_t<bf16, __nv_bfloat16, half>;
using half2_t = typename std::conditional_t<bf16, __nv_bfloat162, half2>;
};
using GEMMConfig_W4A4_FP16 = GEMMConfig_W4A4<false>;
using GEMMConfig_W4A4_BF16 = GEMMConfig_W4A4<true>;
using GEMMConfig_W4A4_FP16_FasterI2F = GEMMConfig_W4A4<false, true>;
class GEMMConfig_W8A8 {
public:
......@@ -199,85 +203,24 @@ public:
static packed_f32psum_t mma_f16xf16_f32(packed_fpsum_t a, packed_fpsum_t b, packed_f32psum_t psum) {
static_assert(std::is_same_v<half_t, half> || std::is_same_v<half_t, __nv_bfloat16>);
if constexpr (std::is_same_v<half_t, half>) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[1])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[3])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
);
}
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[1])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[3])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
);
}
static constexpr bool is_bf16 = std::is_same_v<half_t, __nv_bfloat16>;
uint4 out1 = mma_m16n8k16_f32f16f16f32<is_bf16>(
bit_cast<uint4>(a),
bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])),
bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3])));
uint4 out2 = mma_m16n8k16_f32f16f16f32<is_bf16>(
bit_cast<uint4>(a),
bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])),
bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7])));
psum.data[0] = bit_cast<float>(out1.x);
psum.data[1] = bit_cast<float>(out1.y);
psum.data[2] = bit_cast<float>(out1.z);
psum.data[3] = bit_cast<float>(out1.w);
psum.data[4] = bit_cast<float>(out2.x);
psum.data[5] = bit_cast<float>(out2.y);
psum.data[6] = bit_cast<float>(out2.z);
psum.data[7] = bit_cast<float>(out2.w);
return psum;
}
......@@ -400,7 +343,19 @@ public:
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
}
template<bool FAST_I2F = false, typename F>
struct i2f_normal {
__device__ __forceinline__
static float2 int2float2(int x, int y) {
return make_float2(__int2float_rn(x), __int2float_rn(y));
}
__device__ __forceinline__
static half2_t int2half2(int x, int y) {
return float22half2<half2_t>(int2float2(x, y));
}
};
template<typename i2f = i2f_normal, typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
......@@ -430,30 +385,11 @@ public:
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
auto scale_fma_normal = [&]() ALWAYSINLINE {
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
};
// should be faster on sm_80
auto scale_fma_fast = [&]() ALWAYSINLINE {
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[0]), int2float_fast(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[2]), int2float_fast(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[4]), int2float_fast(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[6]), int2float_fast(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
if constexpr (FAST_I2F) {
scale_fma_fast();
} else {
scale_fma_normal();
}
#else
scale_fma_normal();
#endif
fsum.data[0] = __hfma2(i2f::int2half2(psum.data[0], psum.data[1]), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(i2f::int2half2(psum.data[2], psum.data[3]), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(i2f::int2half2(psum.data[4], psum.data[5]), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(i2f::int2half2(psum.data[6], psum.data[7]), __hmul2(asy[i], ws2), fsum.data[3]);
// if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
......@@ -461,7 +397,7 @@ public:
}
}
template<typename F>
template<typename i2f = i2f_normal, typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, f32psum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
......@@ -490,10 +426,10 @@ public:
packed_psum_t psum = getpsum(i, j);
fma2(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1])), asx[i] * ws1, fsum.data[0], fsum.data[1]);
fma2(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3])), asy[i] * ws1, fsum.data[2], fsum.data[3]);
fma2(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5])), asx[i] * ws2, fsum.data[4], fsum.data[5]);
fma2(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7])), asy[i] * ws2, fsum.data[6], fsum.data[7]);
fma2(i2f::int2float2(psum.data[0], psum.data[1]), asx[i] * ws1, fsum.data[0], fsum.data[1]);
fma2(i2f::int2float2(psum.data[2], psum.data[3]), asy[i] * ws1, fsum.data[2], fsum.data[3]);
fma2(i2f::int2float2(psum.data[4], psum.data[5]), asx[i] * ws2, fsum.data[4], fsum.data[5]);
fma2(i2f::int2float2(psum.data[6], psum.data[7]), asy[i] * ws2, fsum.data[6], fsum.data[7]);
}
}
}
......@@ -863,11 +799,36 @@ public:
using Base::pack_wscales; \
using Base::apply_act;
template<typename kernel>
constexpr int min_arch() {
if constexpr (requires {kernel::MIN_ARCH;}) {
return kernel::MIN_ARCH;
} else {
return 0;
}
}
template<typename kernel>
constexpr int max_arch() {
if constexpr (requires {kernel::MAX_ARCH;}) {
return kernel::MAX_ARCH;
} else {
return INT_MAX;
}
}
template<typename kernel, typename ...T>
__global__
static void invoke_kernel(T ...args) {
#ifdef __CUDA_ARCH__
if constexpr (__CUDA_ARCH__ >= min_arch<kernel>() && __CUDA_ARCH__ <= max_arch<kernel>()) {
kernel()(args...);
} else {
trap_unsupported_arch();
}
#else
// ???
kernel()(args...);
#endif
}
template<typename T>
......
......@@ -195,6 +195,182 @@ static T movmatrix(T x) {
return x;
}
namespace mma_helper {
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
};
__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#endif
return d;
}
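The sm_75 fallback above emulates a single k=16 tensor-core MMA with two k=8 MMAs chained through the accumulator; the f32-accumulating and integer variants below apply the same chaining idea using the m16n8k8 and m8n8 shapes available on Turing. A small NumPy check of the underlying identity, independent of the PTX fragment layout:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((16, 16), dtype=np.float32)
B = rng.standard_normal((16, 8), dtype=np.float32)
C = rng.standard_normal((16, 8), dtype=np.float32)

ref = A @ B + C                   # one m16n8k16 MMA: D = A * B + C
tmp = A[:, :8] @ B[:8, :] + C     # first m16n8k8 over the low half of K
out = A[:, 8:] @ B[8:, :] + tmp   # second m16n8k8 over the high half, reusing the accumulator
assert np.allclose(ref, out)
```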
template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"C"(mma_helper::f16bf16<is_bf16>::value)
);
#else
static_assert(!is_bf16);
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#endif
return d;
}
template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
"C"(AType::value),
"C"(BType::value)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2)
"C"(AType::value),
"C"(BType::value)
);
#endif
return d;
}
// x in low bit, y in high bit
template<int bitwidth, bool use_unsigned>
......@@ -386,4 +562,47 @@ static To bit_cast(const From &input) {
return *reinterpret_cast<const To *>(&input);
}
// both int2float and float2half are slow on sm_75 and before
// val in [-8192, 8191], steps of 16, round to negative inf
__device__ __forceinline__
static half2 int2half2_fast_8192(int x, int y) {
uint32_t ival;
uint32_t hval;
// ival.lo = x.lo; ival.hi = y.lo;
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
ival = ival >> 4;
// (val & 0x03FF03FF) ^ 0x76007600
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}
// val in [-4096, 4095], steps of 8, round to nearest
__device__ __forceinline__
static half2 int2half2_fast_4096_rn(int x, int y) {
// x = max(min(x, 4095), -4096);
// y = max(min(y, 4095), -4096);
// TODO: round to even?
x = x * 8192 + 32768;
y = y * 8192 + 32768;
uint32_t ival;
uint32_t hval;
// ival.lo = x.hi; ival.hi = y.hi;
// <=> divide x and y by 65536 and pack them
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
// (val & 0x03FF03FF) ^ 0x72007200
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// val in [-512, 511]
__device__ __forceinline__
static half2 int2half2_fast_512(int x, int y) {
uint32_t ival;
uint32_t hval;
// ival.lo = x.lo; ival.hi = y.lo;
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
// (val & 0x03FF03FF) ^ 0x66006600
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}
}; // namespace nunchaku::kernels
\ No newline at end of file
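These range-limited converters are what the `FASTER_I2F` "may generate incorrect results in certain circumstances" note above refers to: each variant only matches `__int2float_rn`-style behaviour inside its stated input range and step size. The trick is to splice the (shifted) integer bits into the mantissa of a half with a fixed exponent, then subtract the implied bias. A NumPy check of the per-lane math for the simplest variant, `int2half2_fast_512` (the real kernel handles two lanes at once with `prmt`/`lop3` and `__hadd2`):

```python
import numpy as np

def int2half_fast_512(x: int) -> float:
    # Low 10 bits of x (two's complement) spliced into the mantissa of the half
    # 0x6600 (= 1536.0) yield the half value x + 1536; subtracting 1536 recovers x.
    bits = (x & 0x03FF) ^ 0x6600
    return float(np.uint16(bits).view(np.float16)) - 1536.0

assert all(int2half_fast_512(v) == v for v in range(-512, 512))  # exact for -512 <= v <= 511
```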
......@@ -3,14 +3,32 @@
namespace nunchaku::kernels {
// for sm_75 only
struct FasterI2FMode {
enum Mode {
Disabled = 0,
Enabled,
Always,
};
inline static Mode mode = Disabled;
static bool check(bool act_unsigned);
};
template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, F &&launch) {
if (dtype == Tensor::FP16) {
launch.template operator()<GEMMConfig_W4A4_FP16>();
} else if (dtype == Tensor::BF16) {
launch.template operator()<GEMMConfig_W4A4_BF16>();
static void invoke_launch(Tensor::ScalarType dtype, bool use_fp4, bool fasterI2F, F &&launch) {
if (fasterI2F && dtype == Tensor::FP16) {
launch.template operator()<GEMMConfig_W4A4_FP16_FasterI2F, false>();
} else {
assert(false);
dispatchBool(use_fp4, [&]<bool USE_FP4>() {
if (dtype == Tensor::FP16) {
launch.template operator()<GEMMConfig_W4A4_FP16, USE_FP4>();
} else if (dtype == Tensor::BF16) {
launch.template operator()<GEMMConfig_W4A4_BF16, USE_FP4>();
} else {
assert(false);
}
});
}
}
......@@ -56,72 +74,92 @@ void gemm_w4a4(
}
}
}
invoke_launch(dtype, [&]<typename Config>() {
dispatchBool(fp4, [&]<bool USE_FP4>() {
GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
wcscales,
out_q,
out_k,
out_v,
attn_tokens
);
});
invoke_launch(dtype, fp4, FasterI2FMode::check(act_unsigned), [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
wcscales,
out_q,
out_k,
out_v,
attn_tokens
);
});
}
void linearattn_vk_mul_q(Tensor q, Tensor vk) {
invoke_launch(q.dtype(), [&]<typename Config>() {
invoke_launch(q.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, false>::linearattn_vk_mul_q(q, vk);
});
}
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
invoke_launch(input.dtype(), [&]<typename Config>() {
dispatchBool(fp4, [&]<bool USE_FP4>() {
GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
);
});
invoke_launch(input.dtype(), fp4, false, [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
);
});
}
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), [&]<typename Config>() {
invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(
input, output, oscales
);
});
}
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), [&]<typename Config>() {
invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(
input, output, oscales
);
});
}
bool FasterI2FMode::check(bool act_unsigned) {
auto *prop = getCurrentDeviceProperties();
if (prop->major != 7 || prop->minor != 5) {
return false;
}
if (mode == Always) {
return true;
} else if (mode == Enabled && !act_unsigned) {
return true;
} else {
return false;
}
}
void set_faster_i2f_mode(std::string mode) {
static const std::map<std::string, FasterI2FMode::Mode> mapping = {
{"disabled", FasterI2FMode::Disabled},
{"enabled", FasterI2FMode::Enabled},
{"always", FasterI2FMode::Always},
};
FasterI2FMode::mode = mapping.at(mode);
}
};
\ No newline at end of file