Unverified commit ad8097b9 authored by Muyang Li, committed by GitHub

Release v0.2.0

Ready to release v0.2.0
parents 804a6d30 998192ca
#include "zgemm.h"
#include "gemm_w4a4.cuh"
#include "epilogues.cuh"
namespace nunchaku::kernels {
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16);
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}
}
);
checkCUDA(cudaGetLastError());
}
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
assert(input.ndims() == 2);
const int M = input.shape[0];
const int N = input.shape[1];
assert(input.scalar_type() == Tensor::FP16);
Tensor output = Tensor::empty_like(input);
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0);
assert(N % GEMM::BLOCK_N == 0);
using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
.actualM = numTokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
}
}
);
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -448,6 +448,7 @@ public:
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template<typename Epilogue>
struct gemm_w8a8_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
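// bf16 MMA/conversion paths need sm_80+, while the fp16 build of this kernel also runs on Turing (sm_75), hence the lower bound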
__device__
void operator()(
const packed_act_t *act,
......
#pragma once
#include "gemm_base.cuh"
namespace nunchaku::kernels {
template<typename Config>
class Lora;
#ifndef __INTELLISENSE__
template<typename Config>
class Lora : public GEMMBase<Config> {
#else
template<>
class Lora<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
using Config = GEMMConfig_W4A4_FP16;
#endif
public:
IMPORT_GEMM_BASE(Config);
public:
static constexpr int MAX_RANK = 1024;
static constexpr int WARP_R = 16;
// static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
static constexpr int LORA_R_TILES = WARP_R / 16;
static constexpr int LORA_N_TILES = WARP_N / 16;
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, MAX_RANK / 16>;
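// scales holds one multiplier per 16-rank chunk of the LoRA (MAX_RANK / 16 entries); apply_lora_up reads it as scales[rtile * LORA_R_TILES + r]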
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// [N / 16, rank / 16, WARP_SIZE]
__device__ __forceinline__
static void load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
const int stride_ntile = rank / 16 * WARP_SIZE;
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int roffset = r * WARP_SIZE;
const int noffset = n * stride_ntile;
result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
});
});
}
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static void load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
const float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int i = m * LORA_R_TILES + r;
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
});
// CHECK_NAN(tmp, "load_lora_act.tmp");
});
});
}
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
reduce_add_pred(&ptrlane[offset], val[i].data[j], pred);
});
});
}
// __device__ __forceinline__
// static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
// const int laneId = threadIdx.x % WARP_SIZE;
// float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
// unrolled_loop<LORA_R_TILES>([&]<int r>() {
// unrolled_loop<8>([&]<int j>() {
// constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
// });
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
const packed_fpsum_t *lora_wgt_up;
int rank;
scale_t scales;
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, const float *act, const packed_fpsum_t *wgt, const scale_t &scales, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
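// two-stage software pipeline: while the current 16-rank tile is multiplied into the accumulator, the loads for the next tile are already in flight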
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act[NUM_STAGES]; // 32
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
int dummy = 0;
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
const bool pred = k == 0 ? true : k < rank / WARP_R;
load_lora_act(act, 0, lora_act[k], pred);
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128
auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
lora_act16_warp A_fp16;
for (int m = 0; m < LORA_M_TILES; m++) {
for (int r = 0; r < LORA_R_TILES; r++) {
packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll
for (int j = 0; j < 8; j++) {
pack.data[j] *= scales[rtile * LORA_R_TILES + r];
}
A_fp16[m * LORA_R_TILES + r] = packed_fp32_to_fp16(pack);
}
}
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(A_fp16[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(W[n * LORA_R_TILES + r], "lora_wgt");
f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
}
}
}
};
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
load_lora_act(act, nextk, lora_act[idx], pred);
load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);
compute(lora_act[k2], lora_wgt[k2], f32psum, k1 + k2);
}
}
// NVCC does not know rank > 0 :(
// it will generate a branch instruction to skip the initial load
// the branch splits the basic blocks and prevents the overlap of memory access and computing (packed_fp16_to_fp32)
// add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
for (auto &&data : lora_act[k]) {
#pragma unroll
for (int i = 0; i < 8; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
fpsum = packed_fp32_to_fp16(f32psum);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum");
apply_lora_up(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.scales,
args.rank,
args.alwaysfalse
);
CHECK_NAN(fpsum, "fpsum");
}
};
struct EpilogueLoraDown {
struct Arguments {
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
int rank;
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
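// same double-buffered prefetch as apply_lora_up, but only the weight tiles are staged; the computed partial activations are accumulated into the lora_act buffer through reduce_lora_act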
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
bool pred = k == 0 ? true : k < rank / WARP_R;
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
auto compute = [](lora_wgt_warp W, fpsum_warp fpsum) -> lora_act_warp {
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], W[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
}
return lora_act;
};
int dummy = 0;
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
wgt += kernels::bit_cast<int>(lora_wgt[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);
if (alwaysfalse) {
dummy = clock();
}
lora_act_warp lora_act = compute(lora_wgt[k2], fpsum);
reduce_lora_act(act, k1 + k2, lora_act, true);
}
}
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
apply_lora_down(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.rank,
args.alwaysfalse
);
}
};
};
}; // namespace nunchaku::kernels
\ No newline at end of file
#pragma once
#include <cstdint>
#include "common.h"
// only supports cuda 12.5+
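// rationale: these wrappers splice the element-type strings into the mma opcode via the "C" (constant-string) asm constraint, which older nvcc releases reject; see the explicit-specialization variant of this header for CUDA <= 12.4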
namespace nunchaku::kernels {
namespace mma_helper {
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
};
__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#else
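// pre-sm_80 fallback: emulate m16n8k16 with two chained m16n8k8 MMAs, feeding the first result in as the accumulator of the second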
asm volatile(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#endif
return d;
}
template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"C"(mma_helper::f16bf16<is_bf16>::value)
);
#else
static_assert(!is_bf16);
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#endif
return d;
}
template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K),
"C"(AType::value),
"C"(BType::value)
);
#else
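// pre-sm_80 fallback: split the m16n8 MMA into four m8n8 MMAs at half the K depth (two M halves x two K halves), chaining the K halves through the accumulator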
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2),
"C"(AType::value),
"C"(BType::value)
);
#endif
return d;
}
}; // namespace nunchaku::kernels
\ No newline at end of file
#pragma once
#include <cstdint>
#include "common.h"
// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now
namespace nunchaku::kernels {
namespace mma_helper {
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
};
__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#endif
return d;
}
template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
uint4 d;
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
return d;
}
#endif
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#endif
return d;
}
template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;
template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2)
);
#endif
return d;
}
template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
);
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2)
);
#endif
return d;
}
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -30,7 +30,11 @@ void gemm_w4a4(
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales,
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
);
void linearattn_vk_mul_q(Tensor q, Tensor vk);
......@@ -57,4 +61,19 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N]
// );
void attention_fp16(
Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale
);
// EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode);
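// the FP16 flux.1-dev tests exercise i2f_mode values "enabled" and "always" alongside the default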
// FOR TEST ONLY
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -3,11 +3,12 @@ import os
import random
import datasets
import yaml
from PIL import Image
_CITATION = """\
@misc{li2024playground,
title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
year={2024},
eprint={2402.17245},
......@@ -17,7 +18,7 @@ _CITATION = """\
"""
_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""
......@@ -32,6 +33,8 @@ IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"
CONTROL_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/MJHQ-5000.zip"
class MJHQConfig(datasets.BuilderConfig):
def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
......@@ -46,11 +49,14 @@ class MJHQConfig(datasets.BuilderConfig):
self.return_gt = return_gt
class MJHQ(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
BUILDER_CONFIG_CLASS = MJHQConfig
BUILDER_CONFIGS = [
MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset"),
MJHQConfig(name="MJHQ-control", version=VERSION, description="MJHQ-5K with controls"),
]
DEFAULT_CONFIG_NAME = "MJHQ"
def _info(self):
......@@ -64,6 +70,10 @@ class DCI(datasets.GeneratorBasedBuilder):
"image_root": datasets.Value("string"),
"image_path": datasets.Value("string"),
"split": datasets.Value("string"),
"canny_image_path": datasets.Value("string"),
"cropped_image_path": datasets.Value("string"),
"depth_image_path": datasets.Value("string"),
"mask_image_path": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
......@@ -71,36 +81,75 @@ class DCI(datasets.GeneratorBasedBuilder):
)
def _split_generators(self, dl_manager: datasets.download.DownloadManager):
if self.config.name == "MJHQ":
meta_path = dl_manager.download(META_URL)
image_root = dl_manager.download_and_extract(IMAGE_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
),
]
else:
assert self.config.name == "MJHQ-control"
control_root = dl_manager.download_and_extract(CONTROL_URL)
control_root = os.path.join(control_root, "MJHQ-5000")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"meta_path": os.path.join(control_root, "prompts.yaml"), "image_root": control_root},
),
]
def _generate_examples(self, meta_path: str, image_root: str):
if self.config.name == "MJHQ":
with open(meta_path, "r") as f:
meta = json.load(f)
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
category = meta[name]["category"]
prompt = meta[name]["prompt"]
image_path = os.path.join(image_root, category, f"{name}.jpg")
yield i, {
"filename": name,
"category": category,
"image": Image.open(image_path) if self.config.return_gt else None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": image_path,
"split": self.config.name,
"canny_image_path": None,
"cropped_image_path": None,
"depth_image_path": None,
"mask_image_path": None,
}
else:
assert self.config.name == "MJHQ-control"
meta = yaml.safe_load(open(meta_path, "r"))
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
prompt = meta[name]
yield i, {
"filename": name,
"category": None,
"image": None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": os.path.join(image_root, "images", f"{name}.png"),
"split": self.config.name,
"canny_image_path": os.path.join(image_root, "canny_images", f"{name}.png"),
"cropped_image_path": os.path.join(image_root, "cropped_images", f"{name}.png"),
"depth_image_path": os.path.join(image_root, "depth_images", f"{name}.png"),
"mask_image_path": os.path.join(image_root, "mask_images", f"{name}.png"),
}
......@@ -3,9 +3,16 @@ import random
import datasets
import yaml
from huggingface_hub import snapshot_download
from nunchaku.utils import fetch_or_download
__all__ = ["get_dataset"]
__all__ = ["get_dataset", "load_dataset_yaml", "download_hf_dataset"]
def download_hf_dataset(repo_id: str = "mit-han-lab/nunchaku-test", local_dir: str | None = None) -> str:
path = snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
return path
def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
......@@ -46,10 +53,13 @@ def get_dataset(
path = os.path.join(prefix, f"{name}")
if name == "MJHQ":
dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
elif name == "MJHQ-control":
kwargs["name"] = "MJHQ-control"
dataset = datasets.load_dataset(os.path.join(prefix, "MJHQ"), return_gt=return_gt, **kwargs)
else:
dataset = datasets.Dataset.from_dict(
load_dataset_yaml(
fetch_or_download(f"mit-han-lab/nunchaku-test/{name}.yaml", repo_type="dataset"),
fetch_or_download(f"mit-han-lab/svdquant-datasets/{name}.yaml", repo_type="dataset"),
max_dataset_size=max_dataset_size,
repeat=1,
),
......
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
(0.12, 1024, 1024, 30, None, 1, 0.26),
(0.12, 512, 2048, 30, "anime", 1, 0.4),
],
)
def test_flux_dev_loras(
cache_threshold: float,
height: int,
width: int,
num_inference_steps: int,
lora_name: str,
lora_strength: float,
expected_lpips: float,
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ" if lora_name is None else lora_name,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_names=lora_name,
lora_strengths=lora_strength,
cache_threshold=cache_threshold,
expected_lpips=expected_lpips,
)
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, 50, "flashattn2", False, 0.226),
(2048, 512, 25, "nunchaku-fp16", False, 0.243),
],
)
def test_flux_dev(
height: int, width: int, num_inference_steps: int, attention_impl: str, cpu_offload: bool, expected_lpips: float
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
height=height,
width=width,
num_inference_steps=num_inference_steps,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
[
(25, "realism", 0.9, True, 0.178),
(25, "ghibsky", 1, False, 0.164),
(28, "anime", 1, False, 0.284),
(24, "sketch", 1, True, 0.223),
(28, "yarn", 1, False, 0.211),
(25, "haunted_linework", 1, True, 0.317),
],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name=lora_name,
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_names=lora_name,
lora_strengths=lora_strength,
cache_threshold=0,
expected_lpips=expected_lpips,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_hypersd8_1536x2048():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ",
height=1536,
width=2048,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
attention_impl="nunchaku-fp16",
cpu_offload=True,
lora_names="hypersd8",
lora_strengths=0.125,
cache_threshold=0,
expected_lpips=0.291,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_2048x2048():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ",
height=2048,
width=2048,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
attention_impl="nunchaku-fp16",
cpu_offload=True,
lora_names="turbo8",
lora_strengths=1,
cache_threshold=0,
expected_lpips=0.189,
)
# lora composition
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_yarn_2048x1024():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="yarn",
height=2048,
width=1024,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=True,
lora_names=["turbo8", "yarn"],
lora_strengths=[1, 1],
cache_threshold=0,
expected_lpips=0.252,
)
# large rank loras
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_yarn_1024x1024():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="ghibsky",
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=True,
lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
lora_strengths=[0, 1, 0, 0, 0, 0, 1],
cache_threshold=0,
expected_lpips=0.44,
)
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit",
[
(False, False, 17),
(False, True, 13),
(True, False, 12),
(True, True, 6),
],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
precision = get_precision()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
)
}
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
)
memory = torch.cuda.max_memory_reserved(0) / 1024**3
assert memory < memory_limit
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, "flashattn2", False, 0.250),
(1024, 1024, "nunchaku-fp16", False, 0.255),
(1024, 1024, "flashattn2", True, 0.250),
(1920, 1080, "nunchaku-fp16", False, 0.253),
(2048, 2048, "flashattn2", True, 0.274),
],
)
def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import pytest
import torch
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_canny_dev():
run_test(
precision=get_precision(),
model_name="flux.1-canny-dev",
dataset_name="MJHQ-control",
task="canny",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.103 if get_precision() == "int4" else 0.164,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_depth_dev():
run_test(
precision=get_precision(),
model_name="flux.1-depth-dev",
dataset_name="MJHQ-control",
task="depth",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=10,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.170 if get_precision() == "int4" else 0.120,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev():
run_test(
precision=get_precision(),
model_name="flux.1-fill-dev",
dataset_name="MJHQ-control",
task="fill",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.045,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_canny_lora():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="canny",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
lora_names="canny",
lora_strengths=0.85,
cache_threshold=0,
expected_lpips=0.103,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_depth_lora():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="depth",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=10,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
lora_names="depth",
lora_strengths=0.85,
expected_lpips=0.163,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev_turbo():
run_test(
precision=get_precision(),
model_name="flux.1-fill-dev",
dataset_name="MJHQ-control",
task="fill",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
lora_names="turbo8",
lora_strengths=1,
expected_lpips=0.048,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_redux():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="redux",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=2.5,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.198 if get_precision() == "int4" else 0.55, # redux seems to generate different images on 5090
)
import pytest
from .utils import run_test
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips",
[(1024, 1024, "flashattn2", False, 0.25), (2048, 512, "nunchaku-fp16", False, 0.25)],
)
def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
model_name="shuttle-jaguar",
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import os
import tempfile
import pytest
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from safetensors.torch import save_file
from tqdm import tqdm
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux import comfyui2diffusers, convert_to_nunchaku_flux_lowrank_dict, detect_format, xlab2diffusers
from ..data import get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
def run_pipeline(dataset, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
os.makedirs(save_dir, exist_ok=True)
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
seed = hash_str_to_int(filename)
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **forward_kwargs).images[0]
image.save(os.path.join(save_dir, f"{filename}.png"))
@pytest.mark.parametrize(
"precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips",
[
("int4", 1024, 1024, 4, 0, False, False, 16, 0.258),
("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
("int4", 1920, 1080, 4, 0, False, False, 16, 0.258),
("int4", 600, 800, 4, 0, False, False, 16, 0.29),
],
)
def test_flux_schnell(
precision: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
use_qencoder: bool,
cpu_offload: bool,
max_dataset_size: int,
expected_lpips: float,
):
dataset = get_dataset(name="MJHQ", max_dataset_size=max_dataset_size)
save_root = os.path.join("results", "schnell", f"w{width}h{height}t{num_inference_steps}g{guidance_scale}")
save_dir_16bit = os.path.join(save_root, "bf16")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
save_dir_4bit = os.path.join(
save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}" + ("-cpuoffload" if cpu_offload else "")
)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4", offload=cpu_offload
)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
LORA_PATH_MAP = {
"hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
"realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
"ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
"anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
"sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
"yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
"haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
}
def run_test_flux_dev(
precision: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
use_qencoder: bool,
cpu_offload: bool,
lora_name: str | None,
lora_scale: float,
max_dataset_size: int,
expected_lpips: float,
):
save_root = os.path.join(
"results",
"dev",
f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
+ ("-qencoder" if use_qencoder else "")
+ (f"-{lora_name}_{lora_scale:.1f}" if lora_name else ""),
)
dataset = get_dataset(
name="MJHQ" if lora_name in [None, "hypersd8"] else lora_name, max_dataset_size=max_dataset_size
)
save_dir_16bit = os.path.join(save_root, "bf16")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
if lora_name is not None:
pipeline.load_lora_weights(
os.path.dirname(LORA_PATH_MAP[lora_name]),
weight_name=os.path.basename(LORA_PATH_MAP[lora_name]),
adapter_name="lora",
)
for n, m in pipeline.transformer.named_modules():
if isinstance(m, lora.LoraLayer):
for name in m.scaling.keys():
m.scaling[name] = lora_scale
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
save_dir_4bit = os.path.join(save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}")
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-dev", offload=cpu_offload
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-flux.1-dev", precision="fp4", offload=cpu_offload
)
if lora_name is not None:
lora_path = LORA_PATH_MAP[lora_name]
lora_format = detect_format(lora_path)
if lora_format != "svdquant":
if lora_format == "comfyui":
input_lora = comfyui2diffusers(lora_path)
elif lora_format == "xlab":
input_lora = xlab2diffusers(lora_path)
elif lora_format == "diffusers":
input_lora = lora_path
else:
raise ValueError(f"Invalid LoRA format {lora_format}.")
state_dict = convert_to_nunchaku_flux_lowrank_dict(
"mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors", input_lora
)
with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=True) as tmp_file:
save_file(state_dict, tmp_file.name)
transformer.update_lora_params(tmp_file.name)
else:
transformer.update_lora_params(lora_path)
transformer.set_lora_strength(lora_scale)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
@pytest.mark.parametrize("cpu_offload", [False, True])
def test_flux_dev_base(cpu_offload: bool):
run_test_flux_dev(
precision="int4",
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_name=None,
lora_scale=0,
max_dataset_size=8,
expected_lpips=0.16,
)
def test_flux_dev_qencoder_800x600():
run_test_flux_dev(
precision="int4",
height=800,
width=600,
num_inference_steps=50,
guidance_scale=3.5,
use_qencoder=True,
cpu_offload=False,
lora_name=None,
lora_scale=0,
max_dataset_size=8,
expected_lpips=0.36,
)
def test_flux_dev_hypersd8_1080x1920():
run_test_flux_dev(
precision="int4",
height=1080,
width=1920,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_name="hypersd8",
lora_scale=0.125,
max_dataset_size=8,
expected_lpips=0.44,
)
@pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_scale,cpu_offload,expected_lpips",
[
(25, "realism", 0.9, False, 0.16),
(25, "ghibsky", 1, False, 0.16),
(28, "anime", 1, False, 0.27),
(24, "sketch", 1, False, 0.35),
(28, "yarn", 1, False, 0.22),
(25, "haunted_linework", 1, False, 0.34),
],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_scale, cpu_offload, expected_lpips):
run_test_flux_dev(
precision="int4",
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_name=lora_name,
lora_scale=lora_scale,
max_dataset_size=8,
expected_lpips=expected_lpips,
)
@pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit",
[
(False, False, 17),
(False, True, 13),
(True, False, 12),
(True, True, 6),
],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
torch.cuda.reset_peak_memory_stats()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
)
}
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
).to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
)
memory = torch.cuda.max_memory_reserved(0) / 1024**3
assert memory < memory_limit
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
import pytest
from nunchaku.utils import get_precision
from .utils import run_test
@pytest.mark.skipif(get_precision() == "fp4", reason="Blackwell GPUs. Skip tests for Turing.")
@pytest.mark.parametrize(
"height,width,num_inference_steps,cpu_offload,i2f_mode,expected_lpips",
[
(1024, 1024, 50, True, None, 0.253),
(1024, 1024, 50, True, "enabled", 0.258),
(1024, 1024, 50, True, "always", 0.257),
],
)
def test_flux_dev(
height: int, width: int, num_inference_steps: int, cpu_offload: bool, i2f_mode: str | None, expected_lpips: float
):
run_test(
precision=get_precision(),
dtype="fp16",
model_name="flux.1-dev",
height=height,
width=width,
num_inference_steps=num_inference_steps,
attention_impl="nunchaku-fp16",
cpu_offload=cpu_offload,
i2f_mode=i2f_mode,
expected_lpips=expected_lpips,
)
import os
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from tqdm import tqdm
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux.compose import compose_lora
from ..data import download_hf_dataset, get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
ORIGINAL_REPO_MAP = {
"flux.1-schnell": "black-forest-labs/FLUX.1-schnell",
"flux.1-dev": "black-forest-labs/FLUX.1-dev",
"shuttle-jaguar": "shuttleai/shuttle-jaguar",
"flux.1-canny-dev": "black-forest-labs/FLUX.1-Canny-dev",
"flux.1-depth-dev": "black-forest-labs/FLUX.1-Depth-dev",
"flux.1-fill-dev": "black-forest-labs/FLUX.1-Fill-dev",
}
NUNCHAKU_REPO_PATTERN_MAP = {
"flux.1-schnell": "mit-han-lab/svdq-{precision}-flux.1-schnell",
"flux.1-dev": "mit-han-lab/svdq-{precision}-flux.1-dev",
"shuttle-jaguar": "mit-han-lab/svdq-{precision}-shuttle-jaguar",
"flux.1-canny-dev": "mit-han-lab/svdq-{precision}-flux.1-canny-dev",
"flux.1-depth-dev": "mit-han-lab/svdq-{precision}-flux.1-depth-dev",
"flux.1-fill-dev": "mit-han-lab/svdq-{precision}-flux.1-fill-dev",
}
LORA_PATH_MAP = {
"hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
"turbo8": "alimama-creative/FLUX.1-Turbo-Alpha/diffusion_pytorch_model.safetensors",
"realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
"ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
"anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
"sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
"yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
"haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
"canny": "black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors",
"depth": "black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors",
}
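# each entry is "<Hugging Face repo id>/<weights filename>"; loaders split it with os.path.dirname / os.path.basename (as in the older flux.1-dev test above)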
def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
os.makedirs(save_dir, exist_ok=True)
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
if task == "canny":
processor = CannyDetector()
elif task == "depth":
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
elif task == "redux":
processor = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
else:
assert task in ["t2i", "fill"]
processor = None
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
_forward_kwargs = {k: v for k, v in forward_kwargs.items()}
if task == "canny":
assert forward_kwargs.get("height", 1024) == 1024
assert forward_kwargs.get("width", 1024) == 1024
control_image = load_image(row["canny_image_path"])
control_image = processor(
control_image,
low_threshold=50,
high_threshold=200,
detect_resolution=1024,
image_resolution=1024,
)
_forward_kwargs["control_image"] = control_image
elif task == "depth":
control_image = load_image(row["depth_image_path"])
control_image = processor(control_image)[0].convert("RGB")
_forward_kwargs["control_image"] = control_image
elif task == "fill":
image = load_image(row["image_path"])
mask_image = load_image(row["mask_image_path"])
_forward_kwargs["image"] = image
_forward_kwargs["mask_image"] = mask_image
elif task == "redux":
image = load_image(row["image_path"])
_forward_kwargs.update(processor(image))
seed = hash_str_to_int(filename)
if task == "redux":
image = pipeline(generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
else:
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
image.save(os.path.join(save_dir, f"{filename}.png"))
torch.cuda.empty_cache()
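# run_test is the end-to-end quality check: it generates (or reuses) a 16-bit reference set under test_results/ref,
# then generates 4-bit Nunchaku outputs with the requested options (quantized T5 encoder, attention implementation,
# CPU offload, LoRAs, etc.), and finally asserts that the LPIPS between the two sets stays within 5% of `expected_lpips`.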
def run_test(
precision: str = "int4",
model_name: str = "flux.1-schnell",
dataset_name: str = "MJHQ",
task: str = "t2i",
dtype: str | torch.dtype = torch.bfloat16, # the full precision dtype
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 4,
guidance_scale: float = 3.5,
use_qencoder: bool = False,
attention_impl: str = "flashattn2", # "flashattn2" or "nunchaku-fp16"
cpu_offload: bool = False,
cache_threshold: float = 0,
lora_names: str | list[str] | None = None,
lora_strengths: float | list[float] = 1.0,
max_dataset_size: int = 20,
i2f_mode: str | None = None,
expected_lpips: float = 0.5,
):
if isinstance(dtype, str):
dtype_str = dtype
if dtype == "bf16":
dtype = torch.bfloat16
else:
assert dtype == "fp16"
dtype = torch.float16
else:
if dtype == torch.bfloat16:
dtype_str = "bf16"
else:
assert dtype == torch.float16
dtype_str = "fp16"
dataset = get_dataset(name=dataset_name, max_dataset_size=max_dataset_size)
model_id_16bit = ORIGINAL_REPO_MAP[model_name]
folder_name = f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
if lora_names is None:
lora_names = []
elif isinstance(lora_names, str):
lora_names = [lora_names]
if len(lora_names) > 0:
if isinstance(lora_strengths, (int, float)):
lora_strengths = [lora_strengths]
assert len(lora_names) == len(lora_strengths)
for lora_name, lora_strength in zip(lora_names, lora_strengths):
folder_name += f"-{lora_name}_{lora_strength}"
if not os.path.exists(os.path.join("test_results", "ref")):
ref_root = download_hf_dataset(local_dir=os.path.join("test_results", "ref"))
else:
ref_root = os.path.join("test_results", "ref")
save_dir_16bit = os.path.join(ref_root, dtype_str, model_name, folder_name)
if task in ["t2i", "redux"]:
pipeline_cls = FluxPipeline
elif task in ["canny", "depth"]:
pipeline_cls = FluxControlPipeline
elif task == "fill":
pipeline_cls = FluxFillPipeline
else:
raise NotImplementedError(f"Unknown task {task}!")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline_init_kwargs = {"text_encoder": None, "text_encoder2": None} if task == "redux" else {}
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
pipeline = pipeline.to("cuda")
if len(lora_names) > 0:
for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
lora_path = LORA_PATH_MAP[lora_name]
pipeline.load_lora_weights(
os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name=f"lora_{i}"
)
pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)
run_pipeline(
dataset=dataset,
task=task,
pipeline=pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
precision_str = precision
if use_qencoder:
precision_str += "-qe"
if attention_impl == "flashattn2":
precision_str += "-fa2"
else:
assert attention_impl == "nunchaku-fp16"
precision_str += "-nfp16"
if cpu_offload:
precision_str += "-co"
if cache_threshold > 0:
precision_str += f"-cache{cache_threshold}"
if i2f_mode is not None:
precision_str += f"-i2f{i2f_mode}"
save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
model_id_4bit = NUNCHAKU_REPO_PATTERN_MAP[model_name].format(precision=precision)
if i2f_mode is not None:
nunchaku._C.utils.set_faster_i2f_mode(i2f_mode)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
model_id_4bit, offload=cpu_offload, torch_dtype=dtype
)
transformer.set_attention_impl(attention_impl)
if len(lora_names) > 0:
if len(lora_names) == 1: # directly load the lora
lora_path = LORA_PATH_MAP[lora_names[0]]
lora_strength = lora_strengths[0]
transformer.update_lora_params(lora_path)
transformer.set_lora_strength(lora_strength)
else:
composed_lora = compose_lora(
[
(LORA_PATH_MAP[lora_name], lora_strength)
for lora_name, lora_strength in zip(lora_names, lora_strengths)
]
)
transformer.update_lora_params(composed_lora)
pipeline_init_kwargs["transformer"] = transformer
if task == "redux":
pipeline_init_kwargs.update({"text_encoder": None, "text_encoder_2": None})
elif use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
run_pipeline(
dataset=dataset,
task=task,
pipeline=pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del transformer
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
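# Illustrative sketch only (not part of the shipped test suite): run_test can also be driven directly, e.g. to
# compare a multi-LoRA composition against the 16-bit reference. The LoRA names come from LORA_PATH_MAP above;
# the step count and expected_lpips threshold below are placeholder values, not measured results.
def _example_multi_lora_run():  # underscore prefix keeps pytest from collecting this sketch
    run_test(
        precision="int4",
        model_name="flux.1-dev",
        num_inference_steps=25,
        guidance_scale=3.5,
        lora_names=["ghibsky", "anime"],
        lora_strengths=[1.0, 0.6],
        expected_lpips=0.5,  # placeholder threshold
    )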
# additional requirements for testing
pytest
datasets
torchmetrics
......
import pytest
import torch
from diffusers import SanaPAGPipeline, SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip on Turing GPUs and on fp4 (Blackwell) precision")
def test_sana():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
......@@ -28,6 +31,7 @@ def test_sana():
image.save("sana_1600m.png")
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip on Turing GPUs and on fp4 (Blackwell) precision")
def test_sana_pag():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
pipe = SanaPAGPipeline.from_pretrained(
......