"...pipelines/qwenimage/pipeline_qwenimage_controlnet.py" did not exist on "4e74206b0c443f9d272401f397d781d9d0630073"
Commit e9ad0535 authored by muyangli's avatar muyangli
Browse files

[major] support SANA

parent 9eb2cee0
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16>;
};
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>;
};
\ No newline at end of file
#include "gemm_w4a4_launch.cuh"
namespace nunchaku::kernels {
#ifndef __INTELLISENSE__
template<typename Config>
void GEMM_W4A4_Launch<Config>::gemm_w4a4(
#else
template<>
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
#endif
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
) {
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1] * 2;
assert(K == wgt.shape[1] * 2);
int actualM = 0;
int actualN = 0;
if (out.valid()) {
actualM = out.numel() / out.shape[-1];
actualN = out.shape[-1];
assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
}
spdlog::trace("gemm_w4a4: M={} N={} K={}", M, N, K);
spdlog::trace("act at {}", act.data_ptr());
spdlog::trace("wgt at {}", wgt.data_ptr());
spdlog::trace("ascales at {}", ascales.data_ptr());
spdlog::trace("wscales at {}", wscales.data_ptr());
if (bias.valid()) {
spdlog::trace("bias at {}", bias.data_ptr());
}
int shmem = 0;
auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
// test_sizeof<typename Epilogue::Arguments>();
// std::apply([](auto ...args) {
// (test_sizeof<decltype(args)>(), ...);
// }, args);
using kernel = typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>;
auto func = invoke_kernel<kernel,
const packed_act_t *,
const packed_wgt_t *,
const packed_ascale_t *,
const packed_wscale_t *,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
wscales.data_ptr<packed_wscale_t>(),
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
};
auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
if (!bias.valid()) {
return launch.template operator()<NextEpilogue>(nextArgs);
}
assert(bias.numel() == N);
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = typename GEMM::EpilogueCombination<typename GEMM::EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
typename GEMM::EpilogueBias::Arguments{
.bias = bias.data_ptr<packed_wscale_t>(),
},
nextArgs,
{}
});
};
// auto launch_bias = launch;
auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) {
assert(lora_up.valid() == lora_act_in.valid());
assert(lora_down.valid() == lora_act_out.valid());
if (!lora_up.valid()) {
assert(!lora_down.valid());
return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
}
const int rank_up = lora_up.shape[1];
assert(lora_up.shape[0] == N);
// assert(lora_up.shape[1] == Lora::LORA_RANK);
assert(lora_act_in.shape[0] == M);
assert(lora_act_in.shape[1] == rank_up);
dispatchVal(rank_up, LoraRanks(), [&]<int RANK_UP>() {
using LoraUp = typename GEMM::Lora<RANK_UP>;
using scale_t = typename LoraUp::scale_t;
scale_t scales;
if constexpr (scales.size() > 0) {
assert(lora_scales.size() >= scales.size());
for (size_t i = 0; i < scales.size(); i++) {
scales[i] = lora_scales[i];
}
}
if (!lora_down.valid()) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.scales = scales,
},
midArgs,
nextArgs,
{}
});
}
const int rank_down = lora_down.shape[1];
assert(rank_down == rank_up);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank_down);
lora_act_out.zero_();
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.scales = scales,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
},
nextArgs,
{}
});
// });
});
};
if (qout.valid() && oscales.valid()) {
// dispatchBool(qout_unsigned, [&]<bool USE_UNSIGNED>() {
static constexpr float SHIFT_GELU = 0.171875f;
constexpr bool USE_UNSIGNED = true;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<packed_ascale_t>(),
.shift_value = SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
};
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
},
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
}
} else if (out_linearattn.valid()) {
assert(out_vk.valid());
using Epilogue = typename GEMM::EpilogueLiteLA;
assert(out_vk.dtype() == Tensor::FP32);
assert(out_vk.ndims() == 4);
assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
int batch_size = out_vk.shape[0];
int num_heads = out_vk.shape[1];
assert(isTypeMatch<half_t>(out_linearattn.dtype()));
assert(out_linearattn.ndims() == 3);
assert(out_linearattn.shape[0] == batch_size);
assert(out_linearattn.shape[2] * 3 == N);
int num_tokens = out_linearattn.shape[1];
assert(num_tokens % GEMM::BLOCK_M == 0);
int num_blocks_per_batch = ceilDiv(num_tokens, GEMM::BLOCK_M);
shmem = std::max(shmem, Epilogue::SHMEM_SIZE);
out_vk.zero_();
launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(typename Epilogue::Arguments{
.out_q = out_linearattn.data_ptr<half_t>(),
.out_vk = out_vk.data_ptr<float>(),
.num_blocks_per_batch = num_blocks_per_batch,
.actualM = M,
}, {});
} else if (rotary_emb.valid()) {
assert(norm_q.valid());
assert(norm_k.valid());
// assert(isTypeMatch<half_t>(rotary_emb.scalar_type()));
assert(rotary_emb.scalar_type() == Tensor::FP32);
assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
.pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
.rotary_emb = rotary_emb.data_ptr<float>(),
.rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
.epsilon = 1e-6,
}, {});
} else if (out.valid()) {
using Epilogue = typename GEMM::EpilogueDefault;
typename Epilogue::Arguments args{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
};
if (fuse_silu) {
launch_lora.template operator()<Epilogue, typename GEMM::EpilogueSilu>(args, {});
} else {
launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(args, {});
}
} else {
assert(false);
}
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
using Epilogue = typename GEMM::EpilogueLiteLA;
int batch_size = vk.shape[0];
int num_heads = vk.shape[1];
int num_tokens = q.shape[1];
assert(isTypeMatch<half_t>(q.scalar_type()));
assert(vk.scalar_type() == Tensor::FP32);
int BLOCK_SIZE;
if (num_tokens % 256 == 0) {
BLOCK_SIZE = 256;
} else {
BLOCK_SIZE = 128;
}
invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE>>>(
q.data_ptr<half_t>(),
vk.data_ptr<float>(),
1e-6f,
num_tokens
);
checkCUDA(cudaGetLastError());
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
const int actualM = input.numel() / input.shape[-1];
const int actualN = input.shape[-1];
const int M = ceilDiv(actualM, GEMM::BLOCK_M) * GEMM::BLOCK_M;
const int N = ceilDiv(actualN / (fuse_glu ? 2 : 1), GEMM::BLOCK_N) * GEMM::BLOCK_N;
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
assert(output.shape[-1] == N / 2);
// assert(oscales.dtype() == Tensor::FP16);
assert(isTypeMatch<half_t>(oscales.dtype()));
assert(oscales.numel() == M * N / GEMM::WARP_K);
const int rank = lora_down.shape[1];
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank);
lora_act_out.zero_();
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<packed_ascale_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
}
);
checkCUDA(cudaGetLastError());
});
});
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
int M = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
assert(output.shape[-1] == K / 2);
// assert(oscales.dtype() == Tensor::FP16);
assert(isTypeMatch<half_t>(oscales.dtype()));
assert(oscales.numel() == M * K / GEMM::WARP_K);
dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE>>>(
input.data_ptr<half_t>(),
output.data_ptr<packed_act_t>(),
oscales.data_ptr<packed_ascale_t>(),
K
);
checkCUDA(cudaGetLastError());
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
int N = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.ndims() == 2);
assert(output.shape[0] == N);
assert(output.shape[1] == K / 2);
assert(isTypeMatch<half_t>(oscales.dtype()));
// assert(oscales.dtype() == Tensor::FP16);
assert(oscales.numel() == N * K / GEMM::WARP_K);
dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE>>>(
input.data_ptr<half_t>(),
output.data_ptr<packed_wgt_t>(),
oscales.data_ptr<packed_wscale_t>(),
K
);
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
\ No newline at end of file
#include "zgemm.h"
#include "gemm_w8a8.cuh"
namespace nunchaku::kernels {
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu) {
using GEMM = GEMM_W8A8;
int M = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
assert(output.shape[-1] == fuse_glu ? K / 2 : K);
assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
assert(oscales.numel() == M * 1);
auto launch = [&]<bool FUSE_GLU>() {
using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;
assert(kernel::check(M, K));
dim3 grid = kernel::gridSize(M, K);
dim3 block = kernel::blockSize(M, K);
auto func = invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
func<<<grid, block, kernel::smemSize(M, K)>>>(
input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_act_t>(),
oscales.data_ptr<GEMM::packed_ascale_t>(),
K,
false
);
checkCUDA(cudaGetLastError());
};
if (fuse_glu) {
launch.template operator()<true>();
} else {
launch.template operator()<false>();
}
}
void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales, // [1, N]
Tensor bias
)
{
using GEMM = GEMM_W8A8;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1];
assert(K == wgt.shape[1]);
int actualM = 0;
int actualN = 0;
if (out.valid()) {
actualM = out.numel() / out.shape[-1];
actualN = out.shape[-1];
assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
}
auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
M, N, K, args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
};
auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
if (!bias.valid()) {
return launch.template operator()<NextEpilogue>(nextArgs);
}
assert(bias.numel() == N);
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
GEMM::EpilogueBias::Arguments{
.bias = bias.data_ptr<GEMM::packed_wscale_t>(),
},
nextArgs,
{}
});
};
launch_bias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<GEMM::half_t>(),
.actualM = actualM,
.actualN = actualN,
});
}
#if 0
void gemm_w8a8_fuse_litela(
Tensor act, // [B, (M), K]
Tensor wgt, // [N, K]
Tensor out_q, // [B, (M), N / 3]
Tensor out_vk, // [B, num_heads, head_dim + 1, head_dim]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
) {
using GEMM = GEMM_W8A8;
using Epilogue = GEMM::EpilogueLiteLA;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1];
assert(K == wgt.shape[1]);
assert(out_vk.ndims() == 4);
assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
int batch_size = out_vk.shape[0];
int num_heads = out_vk.shape[1];
assert(M % batch_size == 0);
int batch_m = M / batch_size;
Epilogue::Arguments epilogueArgs;
epilogueArgs.batch_m = act.shape[1];
epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
epilogueArgs.out_vk = out_vk.data_ptr<float>();
checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
const GEMM::packed_act_t *,
const GEMM::packed_wgt_t *,
const GEMM::packed_ascale_t *,
const GEMM::packed_wscale_t *,
// GEMM::half_t *,
int, int, int,
Epilogue::Arguments,
bool,
bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// nullptr,
M, N, K, epilogueArgs,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
out_q.data_ptr<GEMM::half_t>(),
out_vk.data_ptr<float>(),
1e-6f
);
checkCUDA(cudaGetLastError());
}
#endif
}; // namespace nunchaku::kernels
\ No newline at end of file
#pragma once
#include "gemm_base.cuh"
namespace nunchaku::kernels {
class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public:
using psum_warp = std::array<packed_psum_t, WARP_M_TILES * WARP_N_TILES>;
__device__ __forceinline__
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
// packed_psum_t psum;
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
return psum;
}
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
psum[i * WARP_N_TILES + j] = mma(A[i], W[j], psum[i * WARP_N_TILES + j]);
}
}
}
/**
* each warp quantizes a INSN_M * INSN_K (16 * 32) matrix
* input is per-warp (in global memory / shared memory)
* oscales is per-warp (in shared memory)
* output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w})
*/
template<bool input_shmem = false>
__device__ __forceinline__
static void quantize_w8a8_warp(const half_t *input, const half_t *oscales, int stride, packed_act_t &output, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int QUANTIZE_BITWIDTH = 8;
// constexpr int QUANTIZE_BITMASK = 0xff;
// constexpr int QVALUE_MAX = 128; // 4 bit => [-128, 127]
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 4 for 8bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>;
packed_input packs[NUM_PACKWARPS];
// load
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load<input_shmem>(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
}
// quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
const int row = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
const int col = laneId % NUM_PACKS_PER_ROW;
float rscale = cuda_frcp(float(oscales[row]));
uint32_t qpack = 0;
#pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) {
// half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
float2 fval = half22float2(half2_t(packs[i][j], packs[i][j + 1])) * float2(rscale, rscale);
qpack |= quantize_float2<QUANTIZE_BITWIDTH, false>(fval) << (j * QUANTIZE_BITWIDTH);
}
mat[row][col] = qpack;
}
__syncwarp();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
ldmatrix(&mat[row][col], output);
__syncwarp();
}
/**
* each warp finds absmax from a row
*/
template<bool fuse_glu = false>
__device__ __forceinline__
static half_t findmax_warp(const half_t *input, half_t *output_shmem, int K, bool alwaysfalse) {
const int laneId = threadIdx.x % WARP_SIZE;
using packed_input = std::array<half2_t, 4>;
using packed_gated_input = std::array<half_t, 4>;
constexpr int PACK_SIZE = sizeof(packed_input) / sizeof(half_t);
constexpr int NUM_STAGES = 2;
half2_t maxvalue2 = { 0, 0 };
packed_input pack[NUM_STAGES];
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
const int idx = k * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
if (idx < K) {
pack[k] = load(reinterpret_cast<const packed_input *>(&input[idx]));
} else {
pack[k].fill(half2_t(0, 0));
}
}
// int dummy = 0;
// FIXME: pipeline does not work
// TODO: store quantized data to shmem (instead of half)
for (int k1 = 0; k1 < ceilDiv(K, PACK_SIZE * WARP_SIZE); k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
const int nextidx = (k1 + k2 + NUM_STAGES - 1) * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
const int nextk2 = (k2 + NUM_STAGES - 1) % NUM_STAGES;
if (nextidx < K) {
pack[nextk2] = load(reinterpret_cast<const packed_input *>(&input[nextidx]));
} else {
pack[nextk2].fill(half2_t(0, 0));
}
packed_input &p = pack[k2];
if constexpr (fuse_glu) {
packed_gated_input gated;
#pragma unroll
for (int j = 0; j < p.size(); j++) {
gated[j] = p[j].x * gelu_half(p[j].y);
p[j].x = gated[j];
p[j].y = 0;
}
int idx = (k1 + k2) * PACK_SIZE / 2 * WARP_SIZE + laneId * PACK_SIZE / 2;
if (idx < K) {
store<true>(reinterpret_cast<packed_gated_input *>(&output_shmem[idx]), gated);
}
}
#pragma unroll
for (int j = 0; j < p.size(); j++) {
maxvalue2 = __hmax2(maxvalue2, __habs2(p[j]));
}
}
}
// unused_var(dummy, alwaysfalse);
#pragma unroll
for (int mask = 32 / 2; mask > 0; mask /= 2) {
maxvalue2 = __hmax2(maxvalue2, __shfl_xor_sync(~0, maxvalue2, mask));
}
return __hmax(maxvalue2.x, maxvalue2.y);
}
// each thread block quantize WARP_M * K tile (32 * K)
template<bool fuse_glu>
struct quantize_w8a8_act_kernel {
static constexpr bool check(int M, int K) {
const int K2 = fuse_glu ? K / 2 : K;
return M % WARP_M == 0 && K2 % WARP_K == 0;
}
static constexpr dim3 gridSize(int M, int K) {
return dim3(M / WARP_M);
}
static constexpr dim3 blockSize(int M, int K) {
return dim3(NUM_WARPS * 32);
}
static constexpr size_t smemSize(int M, int K) {
if constexpr (!fuse_glu) {
return 0;
}
const int K2 = fuse_glu ? K / 2 : K;
return INSN_M * K2 * sizeof(half_t);
}
__device__
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K, bool alwaysfalse) {
// for quantize kernel
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
const int numWarps = blockDim.x / WARP_SIZE;
// for GEMM kernel
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M);
__shared__ alignas(128) half_t oscale_shmem[WARP_M];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[NUM_WARPS][512];
const int K2 = fuse_glu ? K / 2 : K;
// INSN_M * K2
extern __shared__ uint8_t smem[];
half_t *shmem = reinterpret_cast<half_t *>(smem);
for (int tileM = 0; tileM < WARP_M_TILES; tileM++) {
for (int i = warpId; i < INSN_M; i += numWarps) {
const int rowLocal = tileM * INSN_M + i;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
half_t maxv = findmax_warp<fuse_glu>(input + rowGlobal * K, shmem + i * K2, K, alwaysfalse);
oscale_shmem[rowLocal] = maxv / half_t(127);
// rscale_shmem[rowLocal] = half_t(127) / maxv;
// maxv_shmem[rowLocal] = maxv;
}
__syncthreads();
for (int bk = warpId; bk < K2 / WARP_K; bk += numWarps) {
const int rowLocal = tileM * INSN_M;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
const int col = bk * WARP_K;
packed_act_t tmpout;
if constexpr (fuse_glu) {
quantize_w8a8_warp<true>(
shmem + col,
oscale_shmem + rowLocal,
K2,
tmpout,
&tmp_shmem[warpId]
);
} else {
quantize_w8a8_warp<false>(
input + rowGlobal * K + col,
oscale_shmem + rowLocal,
K,
tmpout,
&tmp_shmem[warpId]
);
}
store(&output[(((bm * K2 / WARP_K + bk) * NUM_WARPS + gemmWarpId) * WARP_M_TILES + tileM) * WARP_SIZE + laneId], tmpout);
}
__syncthreads();
}
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
pack_ascales(oscale_shmem, &oscales[(bm * NUM_WARPS + gemmWarpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
}
};
__device__ __forceinline__
static gated_fpsum_warp apply_glu(fpsum_warp fpsum) {
gated_fpsum_warp result;
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
for (int k = 0; k < 4; k++) {
half_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst = src.x * gelu_half(src.y);
}
}
}
return result;
}
static constexpr int unpack_gated_fpsum_shmem_size = INSN_M * (WARP_N / 2 + 8) * sizeof(half_t);
__device__ __forceinline__
static void unpack_gated_fpsum(gated_fpsum_warp fpsum, half_t *output, int stride, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int PACK_SIZE = WARP_N / 2 / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>;
// +8 to prevent bank conflicts
using matrix_t = half_t[INSN_M][WARP_N / 2 + 8];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
packed_gated_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 + j * INSN_N / 2;
*reinterpret_cast<half_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half_t *>(&mat[row][col + 4]) = fsum.data[2];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[1];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[3];
}
__syncwarp();
for (int row = 0; row < INSN_M; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
}
__syncwarp();
}
}
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
template<typename Epilogue>
__device__ __forceinline__
static void gemm_w8a8_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogeParams,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale; // 1
wscale_warp wscale; // 2
psum_warp psum; // 128
for (auto &pack : psum) {
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
}
}
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
load_ascale(ascales, 0, M, ascale, true);
load_wscale(wscales, 0, N, wscale, true);
for (int k = 0; k < NUM_STAGES - 1; k++) {
load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true);
}
int dummy = 0;
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
// load_wscale<false>(wscales, wscale[idx], pred);
// __syncthreads();
// if (alwaysfalse) {
// dummy = clock();
// }
// if (alwaysfalse) {
// dummy = clock();
// }
compute(A[k2], W[k2], psum);
// if (alwaysfalse) {
// dummy = clock();
// }
// asm volatile ("membar.cta;");
}
}
unused_var(dummy, alwaysfalse);
f32psum_warp f32psum;
#pragma unroll
for (int i = 0; i < f32psum.size(); i++) {
#pragma unroll
for (int j = 0; j < 8; j++) {
f32psum[i].data[j] = 0;
}
}
apply_scales([&](int i, int j) {
return psum[i * WARP_N_TILES + j];
}, ascale, wscale, f32psum);
fpsum_warp fpsum = packed_fp32_to_fp16(f32psum);
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x % 32 == 0) {
// printf("warpId = %d fpsum = %f\n", warpId, (float)fpsum[0].data[0].x);
// }
Epilogue()(binfo, fpsum, M, N, K, epilogeParams);
}
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template<typename Epilogue>
struct gemm_w8a8_kernel {
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
if (swapBlockXY) {
std::swap(binfo.bm, binfo.bn);
std::swap(binfo.numBlocksM, binfo.numBlocksN);
}
const int bm = binfo.bm;
const int bn = binfo.bn;
gemm_w8a8_block<Epilogue>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (1) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES, // only 1 group in W8A8
wscales + bn * (1) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
// #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif
M, N, K,
epilogueArgs,
alwaysfalse
);
}
};
#if 0
struct EpilogueGLU {
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
gated_fpsum_warp gated_fpsum = apply_glu(fpsum);
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_gated_fpsum_shmem_size, 128) * 128];
unpack_gated_fpsum(gated_fpsum, out + warpId * WARP_M * N / 2, N / 2, shmem[warpId]);
}
};
#endif
};
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -3,6 +3,8 @@
#include "common.h"
#include "Tensor.h"
namespace nunchaku::kernels {
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
......@@ -21,11 +23,15 @@ void gemm_w4a4(
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales // [R / 16]
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
);
void linearattn_vk_mul_q(Tensor q, Tensor vk);
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {});
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false);
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
......@@ -33,16 +39,19 @@ void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
Tensor wscales, // [1, N]
Tensor bias // packed ws [N]
);
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);
void gemm_w8a8_fuse_litela(
Tensor act, // [B, (M), K]
Tensor wgt, // [N, K]
Tensor out_q, // [B, (M), N / 3]
Tensor out_vk, // [B, num_heads, head_dim + 1, head_dim]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
);
\ No newline at end of file
// void gemm_w8a8_fuse_litela(
// Tensor act, // [B, (M), K]
// Tensor wgt, // [N, K]
// Tensor out_q, // [B, (M), N / 3]
// Tensor out_vk, // [B, num_heads, head_dim + 1, head_dim]
// Tensor ascales, // [1, M]
// Tensor wscales // [1, N]
// );
}; // namespace nunchaku::kernels
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment