Commit 873a35be authored by muyangli

v0.1.4 ready to release


Co-authored-by: Zhekai Zhang <sxtyzhangzk@gmail.com>
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
Co-authored-by: Yujun Lin <16437040+synxlin@users.noreply.github.com>
parent d9cd6858
......@@ -607,14 +607,22 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
return { hidden_states, encoder_hidden_states };
}
FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : offload(offload) {
for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
if (offload && i > 0) { // don't offload first block
transformer_blocks.back()->setLazyLoad(true);
transformer_blocks.back()->releaseLazyParams();
}
}
for (int i = 0; i < 38; i++) {
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
if (offload) {
single_transformer_blocks.back()->setLazyLoad(true);
single_transformer_blocks.back()->releaseLazyParams();
}
}
}
......@@ -626,22 +634,51 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];
for (auto &&block : transformer_blocks) {
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
}
const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
// txt first, same as diffusers
Tensor concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
}
hidden_states = concat;
encoder_hidden_states = {};
Tensor concat;
for (auto &&block : single_transformer_blocks) {
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
}
auto compute = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
} else {
if (size_t(layer) == transformer_blocks.size()) {
// txt first, same as diffusers
concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
}
hidden_states = concat;
encoder_hidden_states = {};
}
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
}
};
auto load = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
block->loadLazyParams();
} else {
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
block->loadLazyParams();
}
};
auto unload = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
block->releaseLazyParams();
} else {
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
block->releaseLazyParams();
}
};
LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
helper.run();
return hidden_states;
}
\ No newline at end of file
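
For context, the scheduling that the new compute/load/unload callbacks implement via LayerOffloadHelper (defined later in this commit) can be summarized with the standalone sketch below. It is a CPU-only toy with printf placeholders, not the real implementation: in the actual helper the three callbacks are issued on two CUDA streams with events, so the weight prefetch of layer i+1 overlaps the compute of layer i.

#include <cstdio>

int main() {
    const int numLayers = 4;
    auto compute = [](int layer) { std::printf("compute layer %d\n", layer); };
    auto load    = [](int layer) { std::printf("prefetch weights of layer %d\n", layer); };
    auto unload  = [](int layer) { std::printf("release weights of layer %d\n", layer); };

    // Call order produced by LayerOffloadHelper::run() when offload == true.
    // Layer 0 is never loaded or unloaded: FluxModel keeps the first joint block resident.
    for (int layer = 0; layer < numLayers; layer++) {
        compute(layer);                               // compute stream
        if (layer - 1 > 0)         unload(layer - 1); // copy stream
        if (layer + 1 < numLayers) load(layer + 1);   // copy stream
    }
    unload(numLayers - 1);
    return 0;
}
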
......@@ -128,10 +128,13 @@ private:
class FluxModel : public Module {
public:
FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
public:
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
private:
bool offload;
};
\ No newline at end of file
......@@ -16,7 +16,7 @@ GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::Sca
this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
registerParams
(weight, "weight")
(weight, "weight", ParamFlags::LazyLoad)
(bias, "bias")
;
}
......@@ -27,7 +27,7 @@ Tensor GEMM_F16::forward(Tensor x) {
}
GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f)
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device)
{
this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
......@@ -39,7 +39,7 @@ GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::Sca
this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(wzeros, "wzeros")
(bias, "bias")
......@@ -52,7 +52,7 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
if (key == "lora_down") {
const int new_rank = dst.shape[0];
this->lora_rank = new_rank;
......@@ -100,7 +100,7 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
in_features(in_features), out_features(out_features),
in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
use_fp4(use_fp4),
lora_rank(0), dtype(dtype)
lora_rank(0), dtype(dtype), device(device)
{
this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
if (use_fp4) {
......@@ -124,7 +124,7 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
this->wcscales = Tensor::allocate({0}, dtype, device, true);
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(this->bias, "bias")
(lora_down, "lora_down", ParamFlags::Optional)
......@@ -143,7 +143,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
this->lora_rank = dst.shape[1];
this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
} else {
......@@ -152,7 +152,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else if (key == "wcscales") {
assert(src.ndims() == 1);
assert(src.shape[0] == out_features_pad);
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
......@@ -242,15 +242,15 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
// shape[-1] = out_features;
auto shape = TensorShape(qact.actShape.dataExtent);
shape[-1] = out_features;
out = Tensor::allocate(shape, dtype, qweight.device());
out = Tensor::allocate(shape, dtype, device);
} else {
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device);
if (use_fp4) {
qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, device);
} else {
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device);
}
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
qout.is_unsigned = !use_fp4;
qout.actShape = qact.actShape;
......@@ -363,13 +363,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
// shape[-1] = in_features / 2;
QuantizedActivation qact;
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, device);
if (use_fp4) {
qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, device);
} else {
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device);
}
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
qact.is_unsigned = false;
qact.actShape = x.shape.dataExtent;
......@@ -420,7 +420,7 @@ GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::Scala
this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(this->bias, "bias")
;
......
......@@ -36,6 +36,7 @@ public:
int lora_rank;
float lora_scale;
const Device device;
public:
Tensor qweight;
Tensor wscales;
......@@ -86,6 +87,7 @@ public:
std::vector<float> lora_scales; // every 16 ranks share a scale
const Tensor::ScalarType dtype;
const Device device;
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) override;
......
......@@ -9,10 +9,20 @@ protected:
enum class ParamFlags : int {
None = 0,
Optional = 1,
LazyLoad = 2,
};
struct TensorLazyLoadInfo {
TensorShape shape;
Tensor::ScalarType type;
Device device;
Tensor src;
};
struct Param {
Tensor *tensor;
ParamFlags flags;
Tensor *tensor = nullptr;
ParamFlags flags = ParamFlags::None;
TensorLazyLoadInfo lazyInfo;
};
friend inline ParamFlags operator|(ParamFlags lhs, ParamFlags rhs) {
......@@ -21,6 +31,9 @@ protected:
friend inline ParamFlags operator&(ParamFlags lhs, ParamFlags rhs) {
return static_cast<ParamFlags>(static_cast<int>(lhs) & static_cast<int>(rhs));
}
static bool checkFlag(ParamFlags flags, ParamFlags target) {
return int(flags & target);
}
public:
std::string getFullName() const {
......@@ -35,6 +48,12 @@ public:
}
}
std::string getPrefix() const {
std::string fullName = getFullName();
std::string prefix = fullName.empty() ? "" : fullName + ".";
return prefix;
}
void traverse(std::function<void(Module *)> func) {
func(this);
for (Module *c : this->children) {
......@@ -46,8 +65,7 @@ public:
for (Module *c : children) {
c->loadParams(provider, partial);
}
std::string fullName = getFullName();
std::string prefix = fullName.empty() ? "" : fullName + ".";
std::string prefix = getPrefix();
for (auto &&[key, param] : params) {
Tensor src = provider.getTensor(prefix + key);
if (!src.valid()) {
......@@ -56,6 +74,13 @@ public:
}
throw std::runtime_error(spdlog::fmt_lib::format("Tensor {} not found", prefix + key));
}
if (enabledLazyLoad && checkFlag(param.flags, ParamFlags::LazyLoad)) {
param.lazyInfo.src = src;
if (!param.tensor->valid()) {
continue;
}
// keep loading params if param is not released
}
this->loadParam(key, *param.tensor, src);
// tensor->copy_(src);
}
......@@ -66,7 +91,46 @@ public:
this->name = std::move(name);
}
void loadLazyParams() {
traverse([](Module *m) {
for (auto &&[key, param] : m->params) {
if (!checkFlag(param.flags, ParamFlags::LazyLoad)) {
continue;
}
TensorLazyLoadInfo &lazy = param.lazyInfo;
Tensor &dst = *param.tensor;
Tensor src = lazy.src;
if (dst.valid()) {
continue;
}
dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device);
if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) {
throw std::runtime_error(spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
}
m->loadParam(key, dst, src);
}
});
}
void releaseLazyParams() {
traverse([](Module *m) {
if (!m->enabledLazyLoad) {
return;
}
for (auto &&[key, param] : m->params) {
if (checkFlag(param.flags, ParamFlags::LazyLoad)) {
*param.tensor = Tensor{};
}
}
});
}
void setLazyLoad(bool val) {
traverse([val](Module *m) {
m->enabledLazyLoad = val;
});
}
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) {
......@@ -98,6 +162,13 @@ protected:
if (param.valid()) {
params[name].tensor = &param;
params[name].flags = flags;
if (checkFlag(flags, ParamFlags::LazyLoad) && param.valid()) {
TensorLazyLoadInfo &lazy = params[name].lazyInfo;
lazy.shape = param.shape;
lazy.type = param.dtype();
lazy.device = param.device();
}
}
return ParamsRegisterHelper(*this);
}
......@@ -121,4 +192,78 @@ public:
std::string name = "";
std::vector<Module *> children;
std::map<std::string, Param> params;
bool enabledLazyLoad = false;
};
struct LayerOffloadHelper {
using func_t = std::function<void(int)>;
const bool offload;
const int numLayers;
func_t funcCompute, funcLoad, funcUnload;
std::unique_ptr<CUDAStreamWrapper> streamCompute;
std::unique_ptr<CUDAStreamWrapper> streamLoad;
std::unique_ptr<CUDAEventWrapper> eventComputeDone;
std::unique_ptr<CUDAEventWrapper> eventLoadDone;
LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
: offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload)
{
if (offload) {
streamCompute = std::make_unique<CUDAStreamWrapper>();
streamLoad = std::make_unique<CUDAStreamWrapper>();
}
}
void run() {
for (int i = 0; i < numLayers; i++) {
run(i);
}
waitEvent(eventComputeDone.get());
funcUnload(numLayers - 1);
}
private:
void run(int layer) {
if (!offload) {
funcCompute(layer);
} else {
std::unique_ptr<CUDAEventWrapper> nextComputeDone, nextLoadDone;
// issue compute kernels first so that we could still overlap compute and memcpy if memory is not pinned
{
CUDAStreamContext ctx(streamCompute->stream);
waitEvent(eventLoadDone.get());
funcCompute(layer);
nextComputeDone = std::make_unique<CUDAEventWrapper>();
checkCUDA(cudaEventRecord(nextComputeDone->event, getCurrentCUDAStream()));
}
{
CUDAStreamContext ctx(streamLoad->stream);
waitEvent(eventComputeDone.get());
if (layer - 1 > 0) {
funcUnload(layer - 1);
}
if (layer + 1 < numLayers) {
funcLoad(layer + 1);
}
nextLoadDone = std::make_unique<CUDAEventWrapper>();
checkCUDA(cudaEventRecord(nextLoadDone->event, getCurrentCUDAStream()));
}
eventComputeDone = std::move(nextComputeDone);
eventLoadDone = std::move(nextLoadDone);
}
}
static void waitEvent(CUDAEventWrapper *event) {
if (!event) {
return;
}
checkCUDA(cudaStreamWaitEvent(getCurrentCUDAStream(), event->event));
}
};
\ No newline at end of file
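
As a usage note, the lazy-loading machinery above is meant to be driven roughly as follows. This is a hedged sketch rather than a compilable unit: block stands for any Module whose heavy tensors are registered with ParamFlags::LazyLoad, and provider for whatever tensor provider loadParams() already receives; only member functions introduced in this diff are used.

// Assumed lifecycle of lazily loaded parameters (see FluxModel's constructor and forward above).
block.setLazyLoad(true);       // mark the whole subtree as lazily loadable
block.releaseLazyParams();     // drop the on-device copies right after construction
block.loadParams(provider);    // stores each safetensors source in Param::lazyInfo, skips the copy

// Later, once per step, typically via LayerOffloadHelper's load/compute/unload callbacks:
block.loadLazyParams();        // re-allocate from lazyInfo and load the weights onto the device
// block.forward(...);
block.releaseLazyParams();     // free the device memory again until the next step
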
......@@ -28,6 +28,33 @@ private:
mio::mmap_source impl;
};
class SafeTensors::MMapImplRead : public SafeTensors::MMapImpl {
public:
MMapImplRead(const std::string &filename, bool pin) {
std::ifstream fin(filename, std::ios::binary);
fin.seekg(0, std::ios::end);
size_t size = fin.tellg();
fin.seekg(0);
if (pin) {
buffer = std::make_unique<BufferHost>(size);
} else {
buffer = std::make_unique<BufferMalloc>(size);
}
fin.read((char *)buffer->getPtr(), size);
}
virtual size_t size() override {
return buffer->getSize();
}
virtual const char *data() override {
return (const char *)buffer->getPtr();
}
private:
std::unique_ptr<Buffer> buffer;
};
#ifdef __linux__
#include <unistd.h>
......@@ -89,26 +116,78 @@ public:
#endif
SafeTensors::SafeTensors(const std::string &filename) {
this->mapped = std::make_unique<MMapImplMio>(filename);
this->hostRegistered = false;
this->memoryPinned = false;
if (cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable | cudaHostRegisterReadOnly) != cudaSuccess) {
spdlog::warn("Unable to pin memory: {}", cudaGetErrorString(cudaGetLastError()));
// mlock(const_cast<char *>(this->mapped->data()), this->mapped->size());
#ifdef __linux__
spdlog::info("Try MAP_PRIVATE");
this->mapped.reset();
auto methodPrivate = [&]() {
this->mapped = std::make_unique<MMapImplPrivate>(filename);
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
this->hostRegistered = true;
this->memoryPinned = true;
};
auto methodMio = [&]() {
this->mapped = std::make_unique<MMapImplMio>(filename);
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable | cudaHostRegisterReadOnly));
this->hostRegistered = true;
this->memoryPinned = true;
};
auto methodRead = [&]() {
this->mapped = std::make_unique<MMapImplRead>(filename, true);
this->memoryPinned = true;
};
auto methodReadNopin = [&]() {
this->mapped = std::make_unique<MMapImplRead>(filename, false);
};
const std::map<std::string, std::function<void()>> methods = {
{ "PRIVATE", methodPrivate },
{ "MIO", methodMio },
{ "READ", methodRead },
{ "READNOPIN", methodReadNopin },
};
auto tryMethod = [&](std::string name) {
spdlog::debug("Trying to load safetensors using method {}", name);
this->mapped.reset();
try {
methods.at(name)();
return true;
} catch (std::exception &e) {
spdlog::warn("Failed to load safetensors using method {}: {}", name, e.what());
}
return false;
};
if (char *env = getenv("NUNCHAKU_LOAD_METHOD")) {
std::string method = std::string(env);
tryMethod(method);
} else {
#ifdef __linux__
tryMethod("PRIVATE") || tryMethod("MIO") || tryMethod("READ") || tryMethod("READNOPIN");
#else
tryMethod("MIO") || tryMethod("READ") || tryMethod("READNOPIN");
#endif
}
if (!this->mapped) {
throw std::runtime_error("Failed to load safetensors");
}
if (!this->memoryPinned) {
spdlog::warn("Memory not pinned");
}
parseHeader();
}
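
One consequence of the fallback chain above: setting the NUNCHAKU_LOAD_METHOD environment variable pins loading to a single method (PRIVATE, MIO, READ, or READNOPIN) and skips the automatic fallback, so a failure then surfaces directly as the "Failed to load safetensors" error. A minimal sketch of forcing a method from C++ follows; setenv is POSIX, and the include and file path are placeholders.

#include <cstdlib>
// #include "SafeTensors.h"   // placeholder; use the project's actual header

int main() {
    // Force the plain read path without pinned memory, e.g. when cudaHostRegister
    // or pinned allocations are unreliable on the target system.
    setenv("NUNCHAKU_LOAD_METHOD", "READNOPIN", /*overwrite=*/1);
    // SafeTensors st("model.safetensors");  // PRIVATE/MIO/READ would not be attempted
    return 0;
}
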
SafeTensors::~SafeTensors() {
#ifndef _WIN32
checkCUDA(cudaHostUnregister(const_cast<char *>(this->mapped->data())));
#endif
if (this->hostRegistered) {
if (cudaHostUnregister(const_cast<char *>(this->mapped->data())) != cudaSuccess) {
spdlog::warn("cudaHostUnregister failed: {}", cudaGetErrorString(cudaGetLastError()));
}
}
}
void SafeTensors::parseHeader() {
......
......@@ -44,6 +44,7 @@ private:
class MMapImpl;
class MMapImplMio;
class MMapImplPrivate;
class MMapImplRead;
struct TensorInfo {
TensorShape shape;
......@@ -54,4 +55,6 @@ private:
};
std::map<std::string, TensorInfo> tensors;
std::unique_ptr<MMapImpl> mapped;
bool hostRegistered, memoryPinned;
};
\ No newline at end of file
......@@ -85,14 +85,15 @@ public:
if (size == 0) {
this->ptr = nullptr;
}
checkCUDA(cudaMallocAsync(&this->ptr, size, 0)); // use default stream to sync with all other streams
// TODO: buffer used in multiple streams?
checkCUDA(cudaMallocAsync(&this->ptr, size, getCurrentCUDAStream()));
}
virtual ~BufferCUDA() {
if (this->size == 0) {
assert(!this->ptr);
return;
}
checkCUDA(cudaFreeAsync(this->ptr, 0));
checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
}
virtual bool isAsyncBuffer() override {
return true;
......@@ -361,7 +362,7 @@ public:
Tensor &zero_() {
assert(this->is_contiguous());
checkCUDA(cudaMemset(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size()));
checkCUDA(cudaMemsetAsync(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
return *this;
}
Tensor &copy_(Tensor other) {
......
......@@ -63,6 +63,49 @@ inline cudaStream_t getCurrentCUDAStream() {
return stackCUDAStreams.top();
}
struct CUDAStreamContext {
cudaStream_t stream;
CUDAStreamContext(cudaStream_t stream) : stream(stream) {
stackCUDAStreams.push(stream);
}
CUDAStreamContext(const CUDAStreamContext &) = delete;
CUDAStreamContext(CUDAStreamContext &&) = delete;
~CUDAStreamContext() {
assert(stackCUDAStreams.top() == stream);
stackCUDAStreams.pop();
}
};
struct CUDAStreamWrapper {
cudaStream_t stream;
CUDAStreamWrapper() {
checkCUDA(cudaStreamCreate(&stream));
}
CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;
~CUDAStreamWrapper() {
checkCUDA(cudaStreamDestroy(stream));
}
};
struct CUDAEventWrapper {
cudaEvent_t event;
CUDAEventWrapper(unsigned int flags = cudaEventDefault) {
checkCUDA(cudaEventCreateWithFlags(&event, flags));
}
CUDAEventWrapper(const CUDAEventWrapper &) = delete;
CUDAEventWrapper(CUDAEventWrapper &&) = delete;
~CUDAEventWrapper() {
checkCUDA(cudaEventDestroy(event));
}
};
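
A hedged usage sketch of the three RAII helpers above, mirroring how LayerOffloadHelper uses them: push a private stream, record an event on it, then make the previous stream wait on that event. It assumes this header (with checkCUDA and getCurrentCUDAStream) is included and is not a standalone program.

// Assumed usage of CUDAStreamWrapper / CUDAStreamContext / CUDAEventWrapper.
inline void exampleSideStreamWork() {
    CUDAStreamWrapper side;                   // cudaStreamCreate / cudaStreamDestroy via RAII
    CUDAEventWrapper done;                    // cudaEventCreateWithFlags / cudaEventDestroy via RAII
    {
        CUDAStreamContext ctx(side.stream);   // getCurrentCUDAStream() now returns side.stream
        // ... launch kernels or async copies here; they pick up the side stream ...
        checkCUDA(cudaEventRecord(done.event, getCurrentCUDAStream()));
    }                                         // ctx pops the stream stack on scope exit
    // Back on the previous stream: order subsequent work after the side-stream work.
    checkCUDA(cudaStreamWaitEvent(getCurrentCUDAStream(), done.event));
}
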
inline cudaDeviceProp *getCurrentDeviceProperties() {
static thread_local cudaDeviceProp prop;
static thread_local bool propAvailable = false;
......
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "semaphore.h"
#include "gemm_cuda.h"
#include "gemm_awq.h"
//#include "../../../nunchaku/csrc/quantization/dequantize.cuh"
#include "dequantize.cuh"
#include <stdio.h>
......@@ -30,8 +30,8 @@
#endif
#define KERNEL_LAUNCH_CODE \
int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device()); \
int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
Tensor _semaphores = Tensor::empty({num_mn_tiles}, Tensor::INT32, _in_feats.device()); \
auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>()); \
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
......@@ -99,7 +99,7 @@ __inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, u
{
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
__asm__ __volatile__(
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
......@@ -111,7 +111,7 @@ __inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax
{
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
__asm__ __volatile__(
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
......@@ -137,7 +137,7 @@ __device__ __inline__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16
template <>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp)
{
__asm__ __volatile__(
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
......@@ -147,7 +147,7 @@ __device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp
template <>
__device__ __inline__ void mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp)
{
__asm__ __volatile__(
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
......
......@@ -307,7 +307,7 @@ Tensor gemv_awq(
return;
}
if constexpr (M > 0) {
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads>>>(
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
checkCUDA(cudaGetLastError());
......
......@@ -1440,10 +1440,10 @@ public:
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2; // 1 for theta, 2 for {sin, cos} pair
__device__ __forceinline__
static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon) {
static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon, int maxRows) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
......@@ -1470,9 +1470,9 @@ public:
CHECK_NAN(fpsum, "fpsum");
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, INT_MAX, INT_MAX, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, maxRows - warpId * WARP_M, INT_MAX, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
// load rope
pack_rope_t rope;
pack_rope_t rope;
if (laneId < LANES_PER_HEAD) {
// freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
rope = load(reinterpret_cast<const pack_rope_t *>(&rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
......@@ -1508,7 +1508,7 @@ public:
// rope
for (int i = 0; i < PACK_SIZE; i += 2) {
float2 pack2 = half22float2(half2_t(pack[i], pack[i+1]));
CHECK_NAN(freq[i].x, "rope.freq");
CHECK_NAN(freq[i].y, "rope.freq");
CHECK_NAN(freq[i+1].x, "rope.freq");
......@@ -1519,7 +1519,7 @@ public:
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// blockIdx.x, blockIdx.y, warpId, rowId,
// blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
// (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
......@@ -1579,7 +1579,7 @@ public:
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] /= PoolSize;
}
store(reinterpret_cast<pack_t *>(pool_out + warpId * N), reduce_tmp);
}
__syncthreads();
......@@ -1599,13 +1599,14 @@ public:
if (is_q || is_k) {
apply(
fpsum,
fpsum,
args.out + bm * BLOCK_M * args.actualN + bn * BLOCK_N,
M, N, K,
M, N, K,
args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon
args.epsilon,
args.actualM - bm * BLOCK_M
);
} else {
EpilogueDefault()(binfo, fpsum, M, N, K, typename EpilogueDefault::Arguments{
......
......@@ -5,8 +5,13 @@ namespace nunchaku::kernels {
template<typename Config>
class GEMM_W4A4_Launch {
using GEMM = GEMM_W4A4<Config>;
using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96>;
// using LoraRanks = std::integer_sequence<int, 32>;
// using LoraRanks = std::integer_sequence<int, 0, 32>;
using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 160, 176, 224>;
// using LoraRanks = std::integer_sequence<int,
// 0, 32, 48, 64, 80, 96, 112, 128, 144, 160,
// 176, 192, 208, 224, 240, 256, 272, 288, 304, 320,
// 336, 352, 368, 384, 400, 416, 432, 448, 464, 480,
// 496, 512>;
using packed_act_t = typename GEMM::packed_act_t;
using packed_wgt_t = typename GEMM::packed_wgt_t;
......
......@@ -97,7 +97,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
assert(alpha == 1.0f);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
......@@ -134,7 +134,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
assert(ascales.dtype() == Tensor::FP8_E4M3);
assert(wscales.dtype() == Tensor::FP8_E4M3);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_amscale_t>(),
......@@ -375,7 +375,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
BLOCK_SIZE = 128;
}
invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE>>>(
invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
q.data_ptr<half_t>(),
vk.data_ptr<float>(),
1e-6f,
......@@ -428,7 +428,7 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
......@@ -462,7 +462,7 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Te
assert(oscales.numel() == M * K / GEMM::WARP_K);
dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE>>>(
invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
input.data_ptr<half_t>(),
output.data_ptr<packed_act_t>(),
oscales.data_ptr<packed_ascale_t>(),
......@@ -486,7 +486,7 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Te
assert(oscales.numel() == N * K / GEMM::WARP_K);
dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE>>>(
invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
input.data_ptr<half_t>(),
output.data_ptr<packed_wgt_t>(),
oscales.data_ptr<packed_wscale_t>(),
......
import json
import os
import random
import datasets
from PIL import Image
_CITATION = """\
@misc{li2024playground,
title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
year={2024},
eprint={2402.17245},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
"""
_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""
_HOMEPAGE = "https://huggingface.co/datasets/playgroundai/MJHQ-30K"
_LICENSE = (
"Playground v2.5 Community License "
"(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md)"
)
IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/mjhq30k_imgs.zip"
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"
class MJHQConfig(datasets.BuilderConfig):
def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
super(MJHQConfig, self).__init__(
name=kwargs.get("name", "default"),
version=kwargs.get("version", "0.0.0"),
data_dir=kwargs.get("data_dir", None),
data_files=kwargs.get("data_files", None),
description=kwargs.get("description", None),
)
self.max_dataset_size = max_dataset_size
self.return_gt = return_gt
class DCI(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
BUILDER_CONFIG_CLASS = MJHQConfig
BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")]
DEFAULT_CONFIG_NAME = "MJHQ"
def _info(self):
features = datasets.Features(
{
"filename": datasets.Value("string"),
"category": datasets.Value("string"),
"image": datasets.Image(),
"prompt": datasets.Value("string"),
"prompt_path": datasets.Value("string"),
"image_root": datasets.Value("string"),
"image_path": datasets.Value("string"),
"split": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.download.DownloadManager):
meta_path = dl_manager.download(META_URL)
image_root = dl_manager.download_and_extract(IMAGE_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
),
]
def _generate_examples(self, meta_path: str, image_root: str):
with open(meta_path, "r") as f:
meta = json.load(f)
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
category = meta[name]["category"]
prompt = meta[name]["prompt"]
image_path = os.path.join(image_root, category, f"{name}.jpg")
yield i, {
"filename": name,
"category": category,
"image": Image.open(image_path) if self.config.return_gt else None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": image_path,
"split": self.config.name,
}
import os
import random
import datasets
import yaml
from nunchaku.utils import fetch_or_download
__all__ = ["get_dataset"]
def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
meta = yaml.safe_load(open(meta_path, "r"))
names = list(meta.keys())
if max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[:max_dataset_size]
names = sorted(names)
ret = {"filename": [], "prompt": [], "meta_path": []}
idx = 0
for name in names:
prompt = meta[name]
for j in range(repeat):
ret["filename"].append(f"{name}-{j}")
ret["prompt"].append(prompt)
ret["meta_path"].append(meta_path)
idx += 1
return ret
def get_dataset(
name: str,
config_name: str | None = None,
split: str = "train",
return_gt: bool = False,
max_dataset_size: int = 5000,
) -> datasets.Dataset:
prefix = os.path.dirname(__file__)
kwargs = {
"name": config_name,
"split": split,
"trust_remote_code": True,
"token": True,
"max_dataset_size": max_dataset_size,
}
path = os.path.join(prefix, f"{name}")
if name == "MJHQ":
dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
else:
dataset = datasets.Dataset.from_dict(
load_dataset_yaml(
fetch_or_download(f"mit-han-lab/nunchaku-test/{name}.yaml", repo_type="dataset"),
max_dataset_size=max_dataset_size,
repeat=1,
),
features=datasets.Features(
{
"filename": datasets.Value("string"),
"prompt": datasets.Value("string"),
"meta_path": datasets.Value("string"),
}
),
)
return dataset
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from nunchaku import NunchakuFluxTransformer2dModel
def test_flux_dev_canny():
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-canny-dev")
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Canny-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
)
processor = CannyDetector()
control_image = processor(
control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
)
image = pipe(
prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=50, guidance_scale=30.0
).images[0]
image.save("flux.1-canny-dev.png")
def test_flux_dev_depth():
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-depth-dev")
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Depth-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
)
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = pipe(
prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=30, guidance_scale=10.0
).images[0]
image.save("flux.1-depth-dev.png")
def test_flux_dev_fill():
image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png")
mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
prompt="A wooden basket of a cat.",
image=image,
mask_image=mask,
height=1024,
width=1024,
guidance_scale=30,
num_inference_steps=50,
max_sequence_length=512,
).images[0]
image.save("flux.1-fill-dev.png")
def test_flux_dev_redux():
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
pipe_prior_output = pipe_prior_redux(image)
images = pipe(guidance_scale=2.5, num_inference_steps=50, **pipe_prior_output).images
images[0].save("flux.1-redux-dev.png")