"vscode:/vscode.git/clone" did not exist on "9d4cf96d3f11de0b911f78b57914da9d303d3dd7"
Unverified Commit f060b8da authored by Muyang Li, committed by GitHub

[Major] Release v0.1.4

Support 4-bit text encoder and per-layer CPU offloading, reducing FLUX's minimum memory requirement to just 4 GiB while maintaining a 2–3× speedup. Fix various issues related to resolution, LoRA, pin memory, and runtime stability. Check out the release notes for full details!
parents f549dfc6 873a35be
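The reduced footprint can be exercised end to end from Python. Below is a hedged sketch that only uses names appearing elsewhere in this commit (NunchakuFluxTransformer2dModel, FluxPipeline, the svdq-int4-flux.1-dev model id) together with diffusers' standard sequential CPU offload; the nunchaku-specific switches for the 4-bit text encoder and per-layer offloading are described in the release notes and are not shown here.

import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel

# SVDQuant 4-bit FLUX transformer; the model id is taken from the tests added in this commit.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
# Generic diffusers hook that keeps submodules on the CPU and moves each one to the GPU
# only while it executes; it trades some speed for a much smaller VRAM footprint.
pipe.enable_sequential_cpu_offload()

image = pipe("A small cabin in snowy mountains", num_inference_steps=50).images[0]
image.save("flux.1-dev.png")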
@@ -28,6 +28,33 @@ private:
mio::mmap_source impl;
};
class SafeTensors::MMapImplRead : public SafeTensors::MMapImpl {
public:
MMapImplRead(const std::string &filename, bool pin) {
std::ifstream fin(filename, std::ios::binary);
fin.seekg(0, std::ios::end);
size_t size = fin.tellg();
fin.seekg(0);
if (pin) {
buffer = std::make_unique<BufferHost>(size);
} else {
buffer = std::make_unique<BufferMalloc>(size);
}
fin.read((char *)buffer->getPtr(), size);
}
virtual size_t size() override {
return buffer->getSize();
}
virtual const char *data() override {
return (const char *)buffer->getPtr();
}
private:
std::unique_ptr<Buffer> buffer;
};
#ifdef __linux__
#include <unistd.h>
@@ -89,26 +116,78 @@ public:
#endif
SafeTensors::SafeTensors(const std::string &filename) {
this->hostRegistered = false;
this->memoryPinned = false;
auto methodPrivate = [&]() {
this->mapped = std::make_unique<MMapImplPrivate>(filename);
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
this->hostRegistered = true;
this->memoryPinned = true;
};
auto methodMio = [&]() {
this->mapped = std::make_unique<MMapImplMio>(filename);
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable | cudaHostRegisterReadOnly));
this->hostRegistered = true;
this->memoryPinned = true;
};
auto methodRead = [&]() {
this->mapped = std::make_unique<MMapImplRead>(filename, true);
this->memoryPinned = true;
};
auto methodReadNopin = [&]() {
this->mapped = std::make_unique<MMapImplRead>(filename, false);
};
const std::map<std::string, std::function<void()>> methods = {
{ "PRIVATE", methodPrivate },
{ "MIO", methodMio },
{ "READ", methodRead },
{ "READNOPIN", methodReadNopin },
};
auto tryMethod = [&](std::string name) {
spdlog::debug("Trying to load safetensors using method {}", name);
this->mapped.reset();
try {
methods.at(name)();
return true;
} catch (std::exception &e) {
spdlog::warn("Failed to load safetensors using method {}: {}", name, e.what());
}
return false;
};
if (char *env = getenv("NUNCHAKU_LOAD_METHOD")) {
std::string method = std::string(env);
tryMethod(method);
} else {
#ifdef __linux__
tryMethod("PRIVATE") || tryMethod("MIO") || tryMethod("READ") || tryMethod("READNOPIN");
#else
tryMethod("MIO") || tryMethod("READ") || tryMethod("READNOPIN");
#endif
}
if (!this->mapped) {
throw std::runtime_error("Failed to load safetensors");
}
if (!this->memoryPinned) {
spdlog::warn("Memory not pinned");
}
parseHeader();
}
SafeTensors::~SafeTensors() {
if (this->hostRegistered) {
if (cudaHostUnregister(const_cast<char *>(this->mapped->data())) != cudaSuccess) {
spdlog::warn("cudaHostUnregister failed: {}", cudaGetErrorString(cudaGetLastError()));
}
}
}
void SafeTensors::parseHeader() {
...
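As the constructor above shows, the loader now tries several strategies in order (PRIVATE, then MIO, READ, and READNOPIN on Linux) and honors the NUNCHAKU_LOAD_METHOD environment variable as an override. A minimal, hypothetical sketch of forcing a particular method from Python follows; the variable must be set before the C++ extension is loaded so that getenv() sees it, and the import path is illustrative only.

import os

# Must be one of the keys registered in the methods map above:
# "PRIVATE", "MIO", "READ", or "READNOPIN".
os.environ["NUNCHAKU_LOAD_METHOD"] = "READNOPIN"

# Import the extension only after the variable is set.
from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")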
@@ -44,6 +44,7 @@ private:
class MMapImpl;
class MMapImplMio;
class MMapImplPrivate;
class MMapImplRead;
struct TensorInfo {
TensorShape shape;
@@ -54,4 +55,6 @@ private:
};
std::map<std::string, TensorInfo> tensors;
std::unique_ptr<MMapImpl> mapped;
bool hostRegistered, memoryPinned;
};
\ No newline at end of file
@@ -85,14 +85,15 @@ public:
if (size == 0) {
this->ptr = nullptr;
}
// TODO: buffer used in multiple streams?
checkCUDA(cudaMallocAsync(&this->ptr, size, getCurrentCUDAStream()));
}
virtual ~BufferCUDA() {
if (this->size == 0) {
assert(!this->ptr);
return;
}
checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
}
virtual bool isAsyncBuffer() override {
return true;
@@ -217,7 +218,7 @@ class Tensor {
public:
enum ScalarType {
INVALID_SCALAR_TYPE,
INT8, INT16, INT32, INT64,
FP16, FP32, BF16,
FP8_E4M3, FP8_E5M2,
};
@@ -361,7 +362,7 @@ public:
Tensor &zero_() {
assert(this->is_contiguous());
checkCUDA(cudaMemsetAsync(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
return *this;
}
Tensor &copy_(Tensor other) {
@@ -541,6 +542,7 @@ public:
inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
{INT8, 1},
{INT16, 2},
{INT32, 4},
{INT64, 8},
{FP16, 2},
...
@@ -63,6 +63,49 @@ inline cudaStream_t getCurrentCUDAStream() {
return stackCUDAStreams.top();
}
struct CUDAStreamContext {
cudaStream_t stream;
CUDAStreamContext(cudaStream_t stream) : stream(stream) {
stackCUDAStreams.push(stream);
}
CUDAStreamContext(const CUDAStreamContext &) = delete;
CUDAStreamContext(CUDAStreamContext &&) = delete;
~CUDAStreamContext() {
assert(stackCUDAStreams.top() == stream);
stackCUDAStreams.pop();
}
};
struct CUDAStreamWrapper {
cudaStream_t stream;
CUDAStreamWrapper() {
checkCUDA(cudaStreamCreate(&stream));
}
CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;
~CUDAStreamWrapper() {
checkCUDA(cudaStreamDestroy(stream));
}
};
struct CUDAEventWrapper {
cudaEvent_t event;
CUDAEventWrapper(unsigned int flags = cudaEventDefault) {
checkCUDA(cudaEventCreateWithFlags(&event, flags));
}
CUDAEventWrapper(const CUDAEventWrapper &) = delete;
CUDAEventWrapper(CUDAEventWrapper &&) = delete;
~CUDAEventWrapper() {
checkCUDA(cudaEventDestroy(event));
}
};
inline cudaDeviceProp *getCurrentDeviceProperties() {
static thread_local cudaDeviceProp prop;
static thread_local bool propAvailable = false;
...
@@ -28,6 +28,7 @@ Tensor from_torch(at::Tensor input) {
{ at::ScalarType::Float, Tensor::FP32 },
{ at::ScalarType::Half, Tensor::FP16 },
{ at::ScalarType::BFloat16, Tensor::BF16 },
{ at::ScalarType::Short, Tensor::INT16 },
{ at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3 },
{ at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2 },
};
@@ -55,6 +56,7 @@ at::Tensor to_torch(Tensor input) {
{ Tensor::FP32, at::ScalarType::Float },
{ Tensor::FP16, at::ScalarType::Half },
{ Tensor::BF16, at::ScalarType::BFloat16 },
{ Tensor::INT16, at::ScalarType::Short },
{ Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn },
{ Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2 },
};
...
This diff is collapsed.
#pragma once
#include "common.h"
#include "Tensor.h"
Tensor awq_gemm_forward_cuda(
Tensor _in_feats,
Tensor _kernel,
Tensor _scales,
Tensor _zeros);
@@ -307,7 +307,7 @@ Tensor gemv_awq(
return;
}
if constexpr (M > 0) {
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
checkCUDA(cudaGetLastError());
...
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implementation of a CTA-wide semaphore for inter-CTA synchronization.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
// namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
class Semaphore
{
public:
int *lock;
bool wait_thread;
int state;
public:
/// Implements a semaphore to wait for a flag to reach a given value
__host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_),
wait_thread(thread_id < 0 || thread_id == 0),
state(-1)
{
}
/// Permit fetching the synchronization mechanism early
__device__ void fetch()
{
if (wait_thread)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
}
}
/// Gets the internal state
__device__ int get_state() const
{
return state;
}
/// Waits until the semaphore is equal to the given value
__device__ void wait(int status = 0)
{
while (__syncthreads_and(state != status))
{
fetch();
}
__syncthreads();
}
/// Updates the lock with the given result
__device__ void release(int status = 0)
{
__syncthreads();
if (wait_thread)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// } // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -349,6 +349,29 @@ __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
#endif // ENABLE BF16
template <typename f16_t>
__device__ __forceinline__
packed_as<f16_t, 2>::type
f162f162(f16_t x);
template <>
__device__ __forceinline__
packed_as<half, 2>::type
f162f162<half>(half x)
{
return __half2half2(x);
}
#ifdef ENABLE_BF16
template <>
__device__ __forceinline__
packed_as<__nv_bfloat16, 2>::type
f162f162<__nv_bfloat16>(__nv_bfloat16 x)
{
return __bfloat162bfloat162(x);
}
# endif
template <typename To, typename Ti>
__device__ inline To cuda_sum(Ti val)
{
...
@@ -1440,10 +1440,10 @@ public:
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2; // 1 for theta, 2 for {sin, cos} pair
__device__ __forceinline__
static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon, int maxRows) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
@@ -1470,9 +1470,9 @@ public:
CHECK_NAN(fpsum, "fpsum");
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, maxRows - warpId * WARP_M, INT_MAX, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
// load rope
pack_rope_t rope;
if (laneId < LANES_PER_HEAD) {
// freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
rope = load(reinterpret_cast<const pack_rope_t *>(&rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
@@ -1508,7 +1508,7 @@ public:
// rope
for (int i = 0; i < PACK_SIZE; i += 2) {
float2 pack2 = half22float2(half2_t(pack[i], pack[i+1]));
CHECK_NAN(freq[i].x, "rope.freq");
CHECK_NAN(freq[i].y, "rope.freq");
CHECK_NAN(freq[i+1].x, "rope.freq");
@@ -1519,7 +1519,7 @@ public:
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// blockIdx.x, blockIdx.y, warpId, rowId,
// blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
// (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
@@ -1579,7 +1579,7 @@ public:
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] /= PoolSize;
}
store(reinterpret_cast<pack_t *>(pool_out + warpId * N), reduce_tmp);
}
__syncthreads();
@@ -1599,13 +1599,14 @@ public:
if (is_q || is_k) {
apply(
fpsum,
args.out + bm * BLOCK_M * args.actualN + bn * BLOCK_N,
M, N, K,
args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon,
args.actualM - bm * BLOCK_M
);
} else {
EpilogueDefault()(binfo, fpsum, M, N, K, typename EpilogueDefault::Arguments{
...
@@ -5,8 +5,13 @@ namespace nunchaku::kernels {
template<typename Config>
class GEMM_W4A4_Launch {
using GEMM = GEMM_W4A4<Config>;
// using LoraRanks = std::integer_sequence<int, 0, 32>;
using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 160, 176, 224>;
// using LoraRanks = std::integer_sequence<int,
// 0, 32, 48, 64, 80, 96, 112, 128, 144, 160,
// 176, 192, 208, 224, 240, 256, 272, 288, 304, 320,
// 336, 352, 368, 384, 400, 416, 432, 448, 464, 480,
// 496, 512>;
using packed_act_t = typename GEMM::packed_act_t;
using packed_wgt_t = typename GEMM::packed_wgt_t;
...
@@ -97,7 +97,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
assert(alpha == 1.0f);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
@@ -134,7 +134,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
assert(ascales.dtype() == Tensor::FP8_E4M3);
assert(wscales.dtype() == Tensor::FP8_E4M3);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_amscale_t>(),
@@ -375,7 +375,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
BLOCK_SIZE = 128;
}
invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
q.data_ptr<half_t>(),
vk.data_ptr<float>(),
1e-6f,
@@ -428,7 +428,7 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
@@ -462,7 +462,7 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Te
assert(oscales.numel() == M * K / GEMM::WARP_K);
dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
input.data_ptr<half_t>(),
output.data_ptr<packed_act_t>(),
oscales.data_ptr<packed_ascale_t>(),
@@ -486,7 +486,7 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Te
assert(oscales.numel() == N * K / GEMM::WARP_K);
dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
input.data_ptr<half_t>(),
output.data_ptr<packed_wgt_t>(),
oscales.data_ptr<packed_wscale_t>(),
...
import json
import os
import random
import datasets
from PIL import Image
_CITATION = """\
@misc{li2024playground,
title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
year={2024},
eprint={2402.17245},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
"""
_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""
_HOMEPAGE = "https://huggingface.co/datasets/playgroundai/MJHQ-30K"
_LICENSE = (
"Playground v2.5 Community License "
"(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md)"
)
IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/mjhq30k_imgs.zip"
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"
class MJHQConfig(datasets.BuilderConfig):
def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
super(MJHQConfig, self).__init__(
name=kwargs.get("name", "default"),
version=kwargs.get("version", "0.0.0"),
data_dir=kwargs.get("data_dir", None),
data_files=kwargs.get("data_files", None),
description=kwargs.get("description", None),
)
self.max_dataset_size = max_dataset_size
self.return_gt = return_gt
class DCI(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
BUILDER_CONFIG_CLASS = MJHQConfig
BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")]
DEFAULT_CONFIG_NAME = "MJHQ"
def _info(self):
features = datasets.Features(
{
"filename": datasets.Value("string"),
"category": datasets.Value("string"),
"image": datasets.Image(),
"prompt": datasets.Value("string"),
"prompt_path": datasets.Value("string"),
"image_root": datasets.Value("string"),
"image_path": datasets.Value("string"),
"split": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.download.DownloadManager):
meta_path = dl_manager.download(META_URL)
image_root = dl_manager.download_and_extract(IMAGE_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
),
]
def _generate_examples(self, meta_path: str, image_root: str):
with open(meta_path, "r") as f:
meta = json.load(f)
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
category = meta[name]["category"]
prompt = meta[name]["prompt"]
image_path = os.path.join(image_root, category, f"{name}.jpg")
yield i, {
"filename": name,
"category": category,
"image": Image.open(image_path) if self.config.return_gt else None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": image_path,
"split": self.config.name,
}
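The builder above supplies prompts and reference images for FID-based aesthetic evaluation on MJHQ-30K. As a rough, hypothetical sketch of how that score could be computed with torchmetrics (added to the test requirements in this commit), assuming generated and reference images are already available as uint8 tensors:

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

def compute_fid(real_images: torch.Tensor, fake_images: torch.Tensor) -> float:
    # Both tensors are expected as uint8 with shape [N, 3, H, W], e.g. decoded from
    # the MJHQ reference JPEGs and the corresponding model outputs.
    fid = FrechetInceptionDistance(feature=2048)
    fid.update(real_images, real=True)
    fid.update(fake_images, real=False)
    return float(fid.compute())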
import os
import random
import datasets
import yaml
from nunchaku.utils import fetch_or_download
__all__ = ["get_dataset"]
def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
meta = yaml.safe_load(open(meta_path, "r"))
names = list(meta.keys())
if max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[:max_dataset_size]
names = sorted(names)
ret = {"filename": [], "prompt": [], "meta_path": []}
idx = 0
for name in names:
prompt = meta[name]
for j in range(repeat):
ret["filename"].append(f"{name}-{j}")
ret["prompt"].append(prompt)
ret["meta_path"].append(meta_path)
idx += 1
return ret
def get_dataset(
name: str,
config_name: str | None = None,
split: str = "train",
return_gt: bool = False,
max_dataset_size: int = 5000,
) -> datasets.Dataset:
prefix = os.path.dirname(__file__)
kwargs = {
"name": config_name,
"split": split,
"trust_remote_code": True,
"token": True,
"max_dataset_size": max_dataset_size,
}
path = os.path.join(prefix, f"{name}")
if name == "MJHQ":
dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
else:
dataset = datasets.Dataset.from_dict(
load_dataset_yaml(
fetch_or_download(f"mit-han-lab/nunchaku-test/{name}.yaml", repo_type="dataset"),
max_dataset_size=max_dataset_size,
repeat=1,
),
features=datasets.Features(
{
"filename": datasets.Value("string"),
"prompt": datasets.Value("string"),
"meta_path": datasets.Value("string"),
}
),
)
return dataset
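A small, hypothetical usage example of the helper above; the MJHQ branch downloads meta_data.json and the image archive on first use:

dataset = get_dataset("MJHQ", max_dataset_size=100, return_gt=False)
for example in dataset:
    print(example["filename"], example["category"], example["prompt"][:60])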
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from nunchaku import NunchakuFluxTransformer2dModel
def test_flux_dev_canny():
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-canny-dev")
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Canny-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
)
processor = CannyDetector()
control_image = processor(
control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
)
image = pipe(
prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=50, guidance_scale=30.0
).images[0]
image.save("flux.1-canny-dev.png")
def test_flux_dev_depth():
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-depth-dev")
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Depth-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
)
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = pipe(
prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=30, guidance_scale=10.0
).images[0]
image.save("flux.1-depth-dev.png")
def test_flux_dev_fill():
image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png")
mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
prompt="A wooden basket of a cat.",
image=image,
mask_image=mask,
height=1024,
width=1024,
guidance_scale=30,
num_inference_steps=50,
max_sequence_length=512,
).images[0]
image.save("flux.1-fill-dev.png")
def test_flux_dev_redux():
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
pipe_prior_output = pipe_prior_redux(image)
images = pipe(guidance_scale=2.5, num_inference_steps=50, **pipe_prior_output).images
images[0].save("flux.1-redux-dev.png")
This diff is collapsed.
pytest
datasets
torchmetrics
mediapipe
controlnet_aux
peft
git+https://github.com/asomoza/image_gen_aux.git
\ No newline at end of file