Unverified commit ad8097b9, authored by Muyang Li and committed by GitHub

Release v0.2.0

Ready to release v0.2.0
parents 804a6d30 998192ca
#include "zgemm.h"
#include "gemm_w4a4.cuh"
#include "epilogues.cuh"
namespace nunchaku::kernels {
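// Test-only driver: applies the fused RMSNorm + RoPE epilogue to an [M, N] fp16 activation tensor and writes the result to `output`.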
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16);
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}
}
);
checkCUDA(cudaGetLastError());
}
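// Test-only driver: runs the QKV packing epilogue on an [M, N] fp16 activation tensor, scattering rows into the head-strided out_q / out_k / out_v buffers (actualM = numTokens valid rows).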
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.scalar_type() == Tensor::FP16);
Tensor output = Tensor::empty_like(input);
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
.actualM = numTokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
}
}
);
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
@@ -448,6 +448,7 @@ public:
    // out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
    template<typename Epilogue>
    struct gemm_w8a8_kernel {
        static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
        __device__
        void operator()(
            const packed_act_t *act,
......
#pragma once
#include "gemm_base.cuh"
namespace nunchaku::kernels {
template<typename Config>
class Lora;
#ifndef __INTELLISENSE__
template<typename Config>
class Lora : public GEMMBase<Config> {
#else
template<>
class Lora<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
using Config = GEMMConfig_W4A4_FP16;
#endif
public:
IMPORT_GEMM_BASE(Config);
public:
static constexpr int MAX_RANK = 1024;
static constexpr int WARP_R = 16;
// static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
static constexpr int LORA_R_TILES = WARP_R / 16;
static constexpr int LORA_N_TILES = WARP_N / 16;
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, MAX_RANK / 16>;
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// [N / 16, rank / 16, WARP_SIZE]
__device__ __forceinline__
static void load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
const int stride_ntile = rank / 16 * WARP_SIZE;
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int roffset = r * WARP_SIZE;
const int noffset = n * stride_ntile;
result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
});
});
}
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static void load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
const float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int i = m * LORA_R_TILES + r;
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
});
// CHECK_NAN(tmp, "load_lora_act.tmp");
});
});
}
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
reduce_add_pred(&ptrlane[offset], val[i].data[j], pred);
});
});
}
// __device__ __forceinline__
// static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
// const int laneId = threadIdx.x % WARP_SIZE;
// float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
// unrolled_loop<LORA_R_TILES>([&]<int r>() {
// unrolled_loop<8>([&]<int j>() {
// constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
// });
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
const packed_fpsum_t *lora_wgt_up;
int rank;
scale_t scales;
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, const float *act, const packed_fpsum_t *wgt, const scale_t &scales, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
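// Two-stage software pipeline: the next rank tile is prefetched from global memory while the current one feeds the MMA loop below.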
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act[NUM_STAGES]; // 32
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
int dummy = 0;
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
const bool pred = k == 0 ? true : k < rank / WARP_R;
load_lora_act(act, 0, lora_act[k], pred);
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128
auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
lora_act16_warp A_fp16;
for (int m = 0; m < LORA_M_TILES; m++) {
for (int r = 0; r < LORA_R_TILES; r++) {
packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll
for (int j = 0; j < 8; j++) {
pack.data[j] *= scales[rtile * LORA_R_TILES + r];
}
A_fp16[m * LORA_R_TILES + r] = packed_fp32_to_fp16(pack);
}
}
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(A_fp16[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(W[n * LORA_R_TILES + r], "lora_wgt");
f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
}
}
}
};
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
load_lora_act(act, nextk, lora_act[idx], pred);
load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);
compute(lora_act[k2], lora_wgt[k2], f32psum, k1 + k2);
}
}
// NVCC does not know rank > 0 :(
// it will generate a branch instruction to skip the initial load
// the branch splits the basic blocks and prevents the overlap of memory access and computing (packed_fp16_to_fp32)
// add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
for (auto &&data : lora_act[k]) {
#pragma unroll
for (int i = 0; i < 8; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
fpsum = packed_fp32_to_fp16(f32psum);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum");
apply_lora_up(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.scales,
args.rank,
args.alwaysfalse
);
CHECK_NAN(fpsum, "fpsum");
}
};
struct EpilogueLoraDown {
struct Arguments {
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
int rank;
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
bool pred = k == 0 ? true : k < rank / WARP_R;
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
auto compute = [](lora_wgt_warp W, fpsum_warp fpsum) -> lora_act_warp {
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(W[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], W[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
}
return lora_act;
};
int dummy = 0;
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
wgt += kernels::bit_cast<int>(lora_wgt[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);
if (alwaysfalse) {
dummy = clock();
}
lora_act_warp lora_act = compute(lora_wgt[k2], fpsum);
reduce_lora_act(act, k1 + k2, lora_act, true);
}
}
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
apply_lora_down(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.rank,
args.alwaysfalse
);
}
};
};
}; // namespace nunchaku::kernels
#pragma once
#include <cstdint>
#include "common.h"
// only supports cuda 12.5+
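// The wrappers below rely on the "C" string-literal asm constraint to splice the element types (f16/bf16, s4/u4) into the PTX mma mnemonics at compile time.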
namespace nunchaku::kernels {
namespace mma_helper {
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
};
__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#else
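// Pre-sm_80 fallback: emulate m16n8k16 with two chained m16n8k8 HMMA instructions.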
asm volatile(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#endif
return d;
}
template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"C"(mma_helper::f16bf16<is_bf16>::value)
);
#else
static_assert(!is_bf16);
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#endif
return d;
}
template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K),
"C"(AType::value),
"C"(BType::value)
);
#else
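// Pre-sm_80 fallback: decompose the m16n8kK IMMA into four m8n8k(K/2) instructions.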
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2),
"C"(AType::value),
"C"(BType::value)
);
#endif
return d;
}
}; // namespace nunchaku::kernels
#pragma once
#include <cstdint>
#include "common.h"
// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now
namespace nunchaku::kernels {
namespace mma_helper {
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
};
__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#endif
return d;
}
template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
uint4 d;
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
return d;
}
#endif
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#endif
return d;
}
template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;
template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2)
);
#endif
return d;
}
template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2)
);
#endif
return d;
}
}; // namespace nunchaku::kernels
@@ -30,7 +30,11 @@ void gemm_w4a4(
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales,
    Tensor out_q,  // packed attention [B, H, M, D]
    Tensor out_k,  // packed attention [B, H, M, D]
    Tensor out_v,  // packed attention [B, H, M, D]
    int attn_tokens
);

void linearattn_vk_mul_q(Tensor q, Tensor vk);
@@ -57,4 +61,19 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N]
// );

void attention_fp16(
    Tensor q,  // packed [Batch, Head, TokensQ, HEAD_DIM]
    Tensor k,  // packed [Batch, Head, TokensKV, HEAD_DIM]
    Tensor v,  // packed [Batch, Head, TokensKV, HEAD_DIM]
    Tensor o,  // linear [Batch, TokensQ, Head * HEAD_DIM]
    float scale
);

// EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode);

// FOR TEST ONLY
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);

}; // namespace nunchaku::kernels
@@ -3,11 +3,12 @@ import os
import random

import datasets
import yaml
from PIL import Image

_CITATION = """\
@misc{li2024playground,
      title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
      author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
      year={2024},
      eprint={2402.17245},
@@ -17,7 +18,7 @@ _CITATION = """\
"""

_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""
@@ -32,6 +33,8 @@ IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"

CONTROL_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/MJHQ-5000.zip"


class MJHQConfig(datasets.BuilderConfig):
    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
@@ -46,11 +49,14 @@ class MJHQConfig(datasets.BuilderConfig):
        self.return_gt = return_gt


class MJHQ(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    BUILDER_CONFIG_CLASS = MJHQConfig
    BUILDER_CONFIGS = [
        MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset"),
        MJHQConfig(name="MJHQ-control", version=VERSION, description="MJHQ-5K with controls"),
    ]
    DEFAULT_CONFIG_NAME = "MJHQ"

    def _info(self):
@@ -64,6 +70,10 @@ class DCI(datasets.GeneratorBasedBuilder):
                "image_root": datasets.Value("string"),
                "image_path": datasets.Value("string"),
                "split": datasets.Value("string"),
                "canny_image_path": datasets.Value("string"),
                "cropped_image_path": datasets.Value("string"),
                "depth_image_path": datasets.Value("string"),
                "mask_image_path": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
@@ -71,36 +81,75 @@
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        if self.config.name == "MJHQ":
            meta_path = dl_manager.download(META_URL)
            image_root = dl_manager.download_and_extract(IMAGE_URL)
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
                ),
            ]
        else:
            assert self.config.name == "MJHQ-control"
            control_root = dl_manager.download_and_extract(CONTROL_URL)
            control_root = os.path.join(control_root, "MJHQ-5000")
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={"meta_path": os.path.join(control_root, "prompts.yaml"), "image_root": control_root},
                ),
            ]

    def _generate_examples(self, meta_path: str, image_root: str):
        if self.config.name == "MJHQ":
            with open(meta_path, "r") as f:
                meta = json.load(f)
            names = list(meta.keys())
            if self.config.max_dataset_size > 0:
                random.Random(0).shuffle(names)
                names = names[: self.config.max_dataset_size]
                names = sorted(names)
            for i, name in enumerate(names):
                category = meta[name]["category"]
                prompt = meta[name]["prompt"]
                image_path = os.path.join(image_root, category, f"{name}.jpg")
                yield i, {
                    "filename": name,
                    "category": category,
                    "image": Image.open(image_path) if self.config.return_gt else None,
                    "prompt": prompt,
                    "meta_path": meta_path,
                    "image_root": image_root,
                    "image_path": image_path,
                    "split": self.config.name,
                    "canny_image_path": None,
                    "cropped_image_path": None,
                    "depth_image_path": None,
                    "mask_image_path": None,
                }
        else:
            assert self.config.name == "MJHQ-control"
            meta = yaml.safe_load(open(meta_path, "r"))
            names = list(meta.keys())
            if self.config.max_dataset_size > 0:
                random.Random(0).shuffle(names)
                names = names[: self.config.max_dataset_size]
                names = sorted(names)
            for i, name in enumerate(names):
                prompt = meta[name]
                yield i, {
                    "filename": name,
                    "category": None,
                    "image": None,
                    "prompt": prompt,
                    "meta_path": meta_path,
                    "image_root": image_root,
                    "image_path": os.path.join(image_root, "images", f"{name}.png"),
                    "split": self.config.name,
                    "canny_image_path": os.path.join(image_root, "canny_images", f"{name}.png"),
                    "cropped_image_path": os.path.join(image_root, "cropped_images", f"{name}.png"),
                    "depth_image_path": os.path.join(image_root, "depth_images", f"{name}.png"),
                    "mask_image_path": os.path.join(image_root, "mask_images", f"{name}.png"),
                }
@@ -3,9 +3,16 @@ import random
import datasets
import yaml
from huggingface_hub import snapshot_download

from nunchaku.utils import fetch_or_download

__all__ = ["get_dataset", "load_dataset_yaml", "download_hf_dataset"]


def download_hf_dataset(repo_id: str = "mit-han-lab/nunchaku-test", local_dir: str | None = None) -> str:
    path = snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
    return path


def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
@@ -46,10 +53,13 @@ def get_dataset(
    path = os.path.join(prefix, f"{name}")
    if name == "MJHQ":
        dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
    elif name == "MJHQ-control":
        kwargs["name"] = "MJHQ-control"
        dataset = datasets.load_dataset(os.path.join(prefix, "MJHQ"), return_gt=return_gt, **kwargs)
    else:
        dataset = datasets.Dataset.from_dict(
            load_dataset_yaml(
                fetch_or_download(f"mit-han-lab/svdquant-datasets/{name}.yaml", repo_type="dataset"),
                max_dataset_size=max_dataset_size,
                repeat=1,
            ),
......
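For orientation, a minimal usage sketch of the new MJHQ-control configuration wired up above; the absolute import path (`tests.data`) and the keyword handling are assumptions inferred from this diff, not a documented API.

# Hypothetical sketch only: iterate the MJHQ-control split added in this release.
# Assumes the test package is importable as `tests`; adjust the import to your layout.
from tests.data import get_dataset

dataset = get_dataset(name="MJHQ-control", max_dataset_size=4)
for row in dataset:
    # Each row carries the prompt plus paths to the pre-rendered control inputs.
    print(row["filename"], row["prompt"])
    print("  canny:", row["canny_image_path"], "depth:", row["depth_image_path"])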
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
(0.12, 1024, 1024, 30, None, 1, 0.26),
(0.12, 512, 2048, 30, "anime", 1, 0.4),
],
)
def test_flux_dev_loras(
cache_threshold: float,
height: int,
width: int,
num_inference_steps: int,
lora_name: str,
lora_strength: float,
expected_lpips: float,
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ" if lora_name is None else lora_name,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_names=lora_name,
lora_strengths=lora_strength,
cache_threshold=cache_threshold,
expected_lpips=expected_lpips,
)
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, 50, "flashattn2", False, 0.226),
(2048, 512, 25, "nunchaku-fp16", False, 0.243),
],
)
def test_flux_dev(
height: int, width: int, num_inference_steps: int, attention_impl: str, cpu_offload: bool, expected_lpips: float
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
height=height,
width=width,
num_inference_steps=num_inference_steps,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
[
(25, "realism", 0.9, True, 0.178),
(25, "ghibsky", 1, False, 0.164),
(28, "anime", 1, False, 0.284),
(24, "sketch", 1, True, 0.223),
(28, "yarn", 1, False, 0.211),
(25, "haunted_linework", 1, True, 0.317),
],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name=lora_name,
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_names=lora_name,
lora_strengths=lora_strength,
cache_threshold=0,
expected_lpips=expected_lpips,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_hypersd8_1536x2048():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ",
height=1536,
width=2048,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
attention_impl="nunchaku-fp16",
cpu_offload=True,
lora_names="hypersd8",
lora_strengths=0.125,
cache_threshold=0,
expected_lpips=0.291,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_2048x2048():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ",
height=2048,
width=2048,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
attention_impl="nunchaku-fp16",
cpu_offload=True,
lora_names="turbo8",
lora_strengths=1,
cache_threshold=0,
expected_lpips=0.189,
)
# lora composition
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_yarn_2048x1024():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="yarn",
height=2048,
width=1024,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=True,
lora_names=["turbo8", "yarn"],
lora_strengths=[1, 1],
cache_threshold=0,
expected_lpips=0.252,
)
# large rank loras
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_yarn_1024x1024():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="ghibsky",
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=True,
lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
lora_strengths=[0, 1, 0, 0, 0, 0, 1],
cache_threshold=0,
expected_lpips=0.44,
)
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit",
[
(False, False, 17),
(False, True, 13),
(True, False, 12),
(True, True, 6),
],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
precision = get_precision()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
)
}
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
)
memory = torch.cuda.max_memory_reserved(0) / 1024**3
assert memory < memory_limit
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, "flashattn2", False, 0.250),
(1024, 1024, "nunchaku-fp16", False, 0.255),
(1024, 1024, "flashattn2", True, 0.250),
(1920, 1080, "nunchaku-fp16", False, 0.253),
(2048, 2048, "flashattn2", True, 0.274),
],
)
def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import pytest
import torch

from nunchaku.utils import get_precision, is_turing

from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_canny_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-canny-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.103 if get_precision() == "int4" else 0.164,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_depth_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-depth-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.170 if get_precision() == "int4" else 0.120,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.045,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_canny_lora():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        lora_names="canny",
        lora_strengths=0.85,
        cache_threshold=0,
        expected_lpips=0.103,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_depth_lora():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="depth",
        lora_strengths=0.85,
        expected_lpips=0.163,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev_turbo():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="turbo8",
        lora_strengths=1,
        expected_lpips=0.048,
    )


@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_redux():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="redux",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=2.5,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.198 if get_precision() == "int4" else 0.55,  # redux seems to generate different images on 5090
    )
import pytest
from .utils import run_test
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips",
[(1024, 1024, "flashattn2", False, 0.25), (2048, 512, "nunchaku-fp16", False, 0.25)],
)
def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
model_name="shuttle-jaguar",
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import os
import tempfile
import pytest
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from safetensors.torch import save_file
from tqdm import tqdm
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux import comfyui2diffusers, convert_to_nunchaku_flux_lowrank_dict, detect_format, xlab2diffusers
from ..data import get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
def run_pipeline(dataset, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
os.makedirs(save_dir, exist_ok=True)
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
seed = hash_str_to_int(filename)
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **forward_kwargs).images[0]
image.save(os.path.join(save_dir, f"{filename}.png"))
@pytest.mark.parametrize(
"precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips",
[
("int4", 1024, 1024, 4, 0, False, False, 16, 0.258),
("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
("int4", 1920, 1080, 4, 0, False, False, 16, 0.258),
("int4", 600, 800, 4, 0, False, False, 16, 0.29),
],
)
def test_flux_schnell(
precision: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
use_qencoder: bool,
cpu_offload: bool,
max_dataset_size: int,
expected_lpips: float,
):
dataset = get_dataset(name="MJHQ", max_dataset_size=max_dataset_size)
save_root = os.path.join("results", "schnell", f"w{width}h{height}t{num_inference_steps}g{guidance_scale}")
save_dir_16bit = os.path.join(save_root, "bf16")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
save_dir_4bit = os.path.join(
save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}" + ("-cpuoffload" if cpu_offload else "")
)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4", offload=cpu_offload
)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
LORA_PATH_MAP = {
"hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
"realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
"ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
"anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
"sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
"yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
"haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
}
def run_test_flux_dev(
precision: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
use_qencoder: bool,
cpu_offload: bool,
lora_name: str | None,
lora_scale: float,
max_dataset_size: int,
expected_lpips: float,
):
save_root = os.path.join(
"results",
"dev",
f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
+ ("-qencoder" if use_qencoder else "")
+ (f"-{lora_name}_{lora_scale:.1f}" if lora_name else ""),
)
dataset = get_dataset(
name="MJHQ" if lora_name in [None, "hypersd8"] else lora_name, max_dataset_size=max_dataset_size
)
save_dir_16bit = os.path.join(save_root, "bf16")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
if lora_name is not None:
pipeline.load_lora_weights(
os.path.dirname(LORA_PATH_MAP[lora_name]),
weight_name=os.path.basename(LORA_PATH_MAP[lora_name]),
adapter_name="lora",
)
for n, m in pipeline.transformer.named_modules():
if isinstance(m, lora.LoraLayer):
for name in m.scaling.keys():
m.scaling[name] = lora_scale
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
save_dir_4bit = os.path.join(save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}")
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-dev", offload=cpu_offload
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-flux.1-dev", precision="fp4", offload=cpu_offload
)
if lora_name is not None:
lora_path = LORA_PATH_MAP[lora_name]
lora_format = detect_format(lora_path)
if lora_format != "svdquant":
if lora_format == "comfyui":
input_lora = comfyui2diffusers(lora_path)
elif lora_format == "xlab":
input_lora = xlab2diffusers(lora_path)
elif lora_format == "diffusers":
input_lora = lora_path
else:
raise ValueError(f"Invalid LoRA format {lora_format}.")
state_dict = convert_to_nunchaku_flux_lowrank_dict(
"mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors", input_lora
)
with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=True) as tmp_file:
save_file(state_dict, tmp_file.name)
transformer.update_lora_params(tmp_file.name)
else:
transformer.update_lora_params(lora_path)
transformer.set_lora_strength(lora_scale)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
@pytest.mark.parametrize("cpu_offload", [False, True])
def test_flux_dev_base(cpu_offload: bool):
run_test_flux_dev(
precision="int4",
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_name=None,
lora_scale=0,
max_dataset_size=8,
expected_lpips=0.16,
)
def test_flux_dev_qencoder_800x600():
run_test_flux_dev(
precision="int4",
height=800,
width=600,
num_inference_steps=50,
guidance_scale=3.5,
use_qencoder=True,
cpu_offload=False,
lora_name=None,
lora_scale=0,
max_dataset_size=8,
expected_lpips=0.36,
)
def test_flux_dev_hypersd8_1080x1920():
run_test_flux_dev(
precision="int4",
height=1080,
width=1920,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_name="hypersd8",
lora_scale=0.125,
max_dataset_size=8,
expected_lpips=0.44,
)
@pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_scale,cpu_offload,expected_lpips",
[
(25, "realism", 0.9, False, 0.16),
(25, "ghibsky", 1, False, 0.16),
(28, "anime", 1, False, 0.27),
(24, "sketch", 1, False, 0.35),
(28, "yarn", 1, False, 0.22),
(25, "haunted_linework", 1, False, 0.34),
],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_scale, cpu_offload, expected_lpips):
run_test_flux_dev(
precision="int4",
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_name=lora_name,
lora_scale=lora_scale,
max_dataset_size=8,
expected_lpips=expected_lpips,
)
@pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit",
[
(False, False, 17),
(False, True, 13),
(True, False, 12),
(True, True, 6),
],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
torch.cuda.reset_peak_memory_stats()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
)
}
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
).to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
)
memory = torch.cuda.max_memory_reserved(0) / 1024**3
assert memory < memory_limit
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
import pytest
from nunchaku.utils import get_precision
from .utils import run_test
@pytest.mark.skipif(get_precision() == "fp4", reason="Blackwell GPUs. Skip tests for Turing.")
@pytest.mark.parametrize(
"height,width,num_inference_steps,cpu_offload,i2f_mode,expected_lpips",
[
(1024, 1024, 50, True, None, 0.253),
(1024, 1024, 50, True, "enabled", 0.258),
(1024, 1024, 50, True, "always", 0.257),
],
)
def test_flux_dev(
height: int, width: int, num_inference_steps: int, cpu_offload: bool, i2f_mode: str | None, expected_lpips: float
):
run_test(
precision=get_precision(),
dtype="fp16",
model_name="flux.1-dev",
height=height,
width=width,
num_inference_steps=num_inference_steps,
attention_impl="nunchaku-fp16",
cpu_offload=cpu_offload,
i2f_mode=i2f_mode,
expected_lpips=expected_lpips,
)
import os
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from tqdm import tqdm
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux.compose import compose_lora
from ..data import download_hf_dataset, get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
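# The helpers imported from ..utils above are not shown in this diff. The reference sketches
# below are assumptions about their behavior (prefixed with `_ref_` so they do not shadow the
# real imports); the actual implementations in the test utilities may differ.
import hashlib

from PIL import Image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision.transforms.functional import pil_to_tensor


def _ref_hash_str_to_int(s: str) -> int:
    # Deterministic filename -> seed mapping so reruns reproduce the same images.
    return int(hashlib.sha256(s.encode()).hexdigest(), 16) % (1 << 31)


def _ref_already_generate(save_dir: str, expected_num_images: int) -> bool:
    # True if the folder already holds enough rendered PNGs, so generation can be skipped.
    if not os.path.isdir(save_dir):
        return False
    return len([f for f in os.listdir(save_dir) if f.endswith(".png")]) >= expected_num_images


def _ref_compute_lpips(dir_a: str, dir_b: str, device: str = "cuda") -> float:
    # Average LPIPS over image pairs with matching filenames in the two folders.
    metric = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to(device)
    for name in sorted(os.listdir(dir_a)):
        if not name.endswith(".png"):
            continue
        img_a = pil_to_tensor(Image.open(os.path.join(dir_a, name)).convert("RGB")).unsqueeze(0) / 255.0
        img_b = pil_to_tensor(Image.open(os.path.join(dir_b, name)).convert("RGB")).unsqueeze(0) / 255.0
        metric.update(img_a.to(device), img_b.to(device))
    return metric.compute().item()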
ORIGINAL_REPO_MAP = {
"flux.1-schnell": "black-forest-labs/FLUX.1-schnell",
"flux.1-dev": "black-forest-labs/FLUX.1-dev",
"shuttle-jaguar": "shuttleai/shuttle-jaguar",
"flux.1-canny-dev": "black-forest-labs/FLUX.1-Canny-dev",
"flux.1-depth-dev": "black-forest-labs/FLUX.1-Depth-dev",
"flux.1-fill-dev": "black-forest-labs/FLUX.1-Fill-dev",
}
NUNCHAKU_REPO_PATTERN_MAP = {
"flux.1-schnell": "mit-han-lab/svdq-{precision}-flux.1-schnell",
"flux.1-dev": "mit-han-lab/svdq-{precision}-flux.1-dev",
"shuttle-jaguar": "mit-han-lab/svdq-{precision}-shuttle-jaguar",
"flux.1-canny-dev": "mit-han-lab/svdq-{precision}-flux.1-canny-dev",
"flux.1-depth-dev": "mit-han-lab/svdq-{precision}-flux.1-depth-dev",
"flux.1-fill-dev": "mit-han-lab/svdq-{precision}-flux.1-fill-dev",
}
LORA_PATH_MAP = {
"hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
"turbo8": "alimama-creative/FLUX.1-Turbo-Alpha/diffusion_pytorch_model.safetensors",
"realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
"ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
"anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
"sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
"yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
"haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
"canny": "black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors",
"depth": "black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors",
}
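# Each LORA_PATH_MAP value is "<HF repo id>/<weight filename>". run_test splits it with
# os.path.dirname / os.path.basename before handing it to diffusers, e.g. (illustrative):
#   lora_path = LORA_PATH_MAP["realism"]   # "XLabs-AI/flux-RealismLora/lora.safetensors"
#   pipeline.load_lora_weights(os.path.dirname(lora_path), weight_name=os.path.basename(lora_path))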
def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
os.makedirs(save_dir, exist_ok=True)
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
if task == "canny":
processor = CannyDetector()
elif task == "depth":
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
elif task == "redux":
processor = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
else:
assert task in ["t2i", "fill"]
processor = None
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
_forward_kwargs = {k: v for k, v in forward_kwargs.items()}
if task == "canny":
assert forward_kwargs.get("height", 1024) == 1024
assert forward_kwargs.get("width", 1024) == 1024
control_image = load_image(row["canny_image_path"])
control_image = processor(
control_image,
low_threshold=50,
high_threshold=200,
detect_resolution=1024,
image_resolution=1024,
)
_forward_kwargs["control_image"] = control_image
elif task == "depth":
control_image = load_image(row["depth_image_path"])
control_image = processor(control_image)[0].convert("RGB")
_forward_kwargs["control_image"] = control_image
elif task == "fill":
image = load_image(row["image_path"])
mask_image = load_image(row["mask_image_path"])
_forward_kwargs["image"] = image
_forward_kwargs["mask_image"] = mask_image
elif task == "redux":
image = load_image(row["image_path"])
_forward_kwargs.update(processor(image))
seed = hash_str_to_int(filename)
if task == "redux":
image = pipeline(generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
else:
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
image.save(os.path.join(save_dir, f"{filename}.png"))
torch.cuda.empty_cache()
def run_test(
precision: str = "int4",
model_name: str = "flux.1-schnell",
dataset_name: str = "MJHQ",
task: str = "t2i",
dtype: str | torch.dtype = torch.bfloat16, # the full precision dtype
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 4,
guidance_scale: float = 3.5,
use_qencoder: bool = False,
attention_impl: str = "flashattn2", # "flashattn2" or "nunchaku-fp16"
cpu_offload: bool = False,
cache_threshold: float = 0,
lora_names: str | list[str] | None = None,
lora_strengths: float | list[float] = 1.0,
max_dataset_size: int = 20,
i2f_mode: str | None = None,
expected_lpips: float = 0.5,
):
if isinstance(dtype, str):
dtype_str = dtype
if dtype == "bf16":
dtype = torch.bfloat16
else:
assert dtype == "fp16"
dtype = torch.float16
else:
if dtype == torch.bfloat16:
dtype_str = "bf16"
else:
assert dtype == torch.float16
dtype_str = "fp16"
dataset = get_dataset(name=dataset_name, max_dataset_size=max_dataset_size)
model_id_16bit = ORIGINAL_REPO_MAP[model_name]
folder_name = f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
if lora_names is None:
lora_names = []
elif isinstance(lora_names, str):
lora_names = [lora_names]
if len(lora_names) > 0:
if isinstance(lora_strengths, (int, float)):
lora_strengths = [lora_strengths]
assert len(lora_names) == len(lora_strengths)
for lora_name, lora_strength in zip(lora_names, lora_strengths):
folder_name += f"-{lora_name}_{lora_strength}"
if not os.path.exists(os.path.join("test_results", "ref")):
ref_root = download_hf_dataset(local_dir=os.path.join("test_results", "ref"))
else:
ref_root = os.path.join("test_results", "ref")
save_dir_16bit = os.path.join(ref_root, dtype_str, model_name, folder_name)
if task in ["t2i", "redux"]:
pipeline_cls = FluxPipeline
elif task in ["canny", "depth"]:
pipeline_cls = FluxControlPipeline
elif task == "fill":
pipeline_cls = FluxFillPipeline
else:
raise NotImplementedError(f"Unknown task {task}!")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline_init_kwargs = {"text_encoder": None, "text_encoder2": None} if task == "redux" else {}
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
pipeline = pipeline.to("cuda")
if len(lora_names) > 0:
for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
lora_path = LORA_PATH_MAP[lora_name]
pipeline.load_lora_weights(
os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name=f"lora_{i}"
)
pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)
run_pipeline(
dataset=dataset,
task=task,
pipeline=pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
precision_str = precision
if use_qencoder:
precision_str += "-qe"
if attention_impl == "flashattn2":
precision_str += "-fa2"
else:
assert attention_impl == "nunchaku-fp16"
precision_str += "-nfp16"
if cpu_offload:
precision_str += "-co"
if cache_threshold > 0:
precision_str += f"-cache{cache_threshold}"
if i2f_mode is not None:
precision_str += f"-i2f{i2f_mode}"
save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
model_id_4bit = NUNCHAKU_REPO_PATTERN_MAP[model_name].format(precision=precision)
if i2f_mode is not None:
nunchaku._C.utils.set_faster_i2f_mode(i2f_mode)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
model_id_4bit, offload=cpu_offload, torch_dtype=dtype
)
transformer.set_attention_impl(attention_impl)
if len(lora_names) > 0:
if len(lora_names) == 1: # directly load the lora
lora_path = LORA_PATH_MAP[lora_names[0]]
lora_strength = lora_strengths[0]
transformer.update_lora_params(lora_path)
transformer.set_lora_strength(lora_strength)
else:
composed_lora = compose_lora(
[
(LORA_PATH_MAP[lora_name], lora_strength)
for lora_name, lora_strength in zip(lora_names, lora_strengths)
]
)
transformer.update_lora_params(composed_lora)
pipeline_init_kwargs["transformer"] = transformer
if task == "redux":
pipeline_init_kwargs.update({"text_encoder": None, "text_encoder_2": None})
elif use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
run_pipeline(
dataset=dataset,
task=task,
pipeline=pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del transformer
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
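# Hedged usage sketch: run_test can also be invoked directly for a quick local check.
# The parameter values below are illustrative, not reference thresholds from the test suite.
if __name__ == "__main__":
    run_test(
        precision="int4",
        model_name="flux.1-schnell",
        task="t2i",
        height=1024,
        width=1024,
        num_inference_steps=4,
        guidance_scale=0,
        max_dataset_size=4,
        expected_lpips=0.5,
    )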
# additional requirements for testing
pytest
datasets
torchmetrics
...
import pytest
import torch
from diffusers import SanaPAGPipeline, SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip on Turing GPUs or with fp4 (Blackwell) precision.")
def test_sana():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
...@@ -28,6 +31,7 @@ def test_sana():
image.save("sana_1600m.png")
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip on Turing GPUs or with fp4 (Blackwell) precision.")
def test_sana_pag():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
pipe = SanaPAGPipeline.from_pretrained(
...