Commit 873a35be authored by muyangli

v0.1.4 ready to release


Co-authored-by: Zhekai Zhang <sxtyzhangzk@gmail.com>
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
Co-authored-by: Yujun Lin <16437040+synxlin@users.noreply.github.com>
parent d9cd6858
@@ -607,14 +607,22 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
return { hidden_states, encoder_hidden_states };
}
FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : offload(offload) {
for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
if (offload && i > 0) { // don't offload first block
transformer_blocks.back()->setLazyLoad(true);
transformer_blocks.back()->releaseLazyParams();
}
}
for (int i = 0; i < 38; i++) {
single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
if (offload) {
single_transformer_blocks.back()->setLazyLoad(true);
single_transformer_blocks.back()->releaseLazyParams();
}
}
}
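Note on the new constructor: the offload flag is inserted between use_fp4 and dtype. A hypothetical call site, illustrative only (whether FP4 weights and a BF16 dtype are the right choices depends on the GPU and the checkpoint being loaded):

// Illustrative usage sketch, not code from this commit: enable CPU offloading so
// every block except the first joint transformer block is lazily loaded and
// released around its forward pass.
FluxModel model(/*use_fp4=*/true, /*offload=*/true, Tensor::BF16, Device::cuda());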
@@ -626,22 +634,51 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];
for (auto &&block : transformer_blocks) {
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
}
const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
// txt first, same as diffusers
Tensor concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
}
hidden_states = concat;
encoder_hidden_states = {};
Tensor concat;
for (auto &&block : single_transformer_blocks) {
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
}
auto compute = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
} else {
if (size_t(layer) == transformer_blocks.size()) {
// txt first, same as diffusers
concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
}
hidden_states = concat;
encoder_hidden_states = {};
}
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
}
};
auto load = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
block->loadLazyParams();
} else {
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
block->loadLazyParams();
}
};
auto unload = [&](int layer) {
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
block->releaseLazyParams();
} else {
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
block->releaseLazyParams();
}
};
LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
helper.run();
return hidden_states;
}
\ No newline at end of file
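The compute/load/unload lambdas above are driven by LayerOffloadHelper, whose implementation sits in a file that is collapsed in this diff. Below is a minimal synchronous sketch of what such a helper could look like; the real class most likely prefetches the next layer's weights while the current layer computes, and the name LayerOffloadHelperSketch plus everything in its body are assumptions, not the repository's code.

#include <functional>

// Assumed shape only: when offload is false the helper simply runs compute for
// each layer; when true it loads a layer's lazy params, computes, then releases
// them. A production version would overlap load(i + 1) with compute(i).
class LayerOffloadHelperSketch {
public:
    using Fn = std::function<void(int)>;
    LayerOffloadHelperSketch(bool offload, int numLayers, Fn compute, Fn load, Fn unload)
        : offload(offload), numLayers(numLayers),
          compute(std::move(compute)), load(std::move(load)), unload(std::move(unload)) {}

    void run() {
        for (int i = 0; i < numLayers; i++) {
            if (offload) load(i);    // bring this layer's weights onto the GPU
            compute(i);              // run the layer
            if (offload) unload(i);  // release the weights again
        }
    }

private:
    bool offload;
    int numLayers;
    Fn compute, load, unload;
};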
@@ -128,10 +128,13 @@ private:
class FluxModel : public Module {
public:
FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
public:
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
private:
bool offload;
};
\ No newline at end of file
@@ -16,7 +16,7 @@ GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::Sca
this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
registerParams
(weight, "weight")
(weight, "weight", ParamFlags::LazyLoad)
(bias, "bias")
;
}
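ParamFlags::LazyLoad is applied only to the large weight tensors (weight here, qweight in the quantized layers below), so scales and biases stay resident while the bulky weights can be released and reloaded on demand. The flag type itself is defined in a header not shown in this diff; a plausible bit-flag layout, stated purely as an assumption, is:

#include <cstdint>

// Hypothetical definition (the real one lives in the module header, not in this
// diff): bit flags so a parameter can carry several attributes, e.g. Optional | LazyLoad.
enum class ParamFlags : uint32_t {
    None     = 0,
    Optional = 1u << 0,
    LazyLoad = 1u << 1,
};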
@@ -27,7 +27,7 @@ Tensor GEMM_F16::forward(Tensor x) {
}
GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f)
in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device)
{
this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
@@ -39,7 +39,7 @@ GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::Sca
this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(wzeros, "wzeros")
(bias, "bias")
@@ -52,7 +52,7 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
if (key == "lora_down") {
const int new_rank = dst.shape[0];
this->lora_rank = new_rank;
@@ -100,7 +100,7 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
in_features(in_features), out_features(out_features),
in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
use_fp4(use_fp4),
lora_rank(0), dtype(dtype)
lora_rank(0), dtype(dtype), device(device)
{
this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
if (use_fp4) {
@@ -124,7 +124,7 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
this->wcscales = Tensor::allocate({0}, dtype, device, true);
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(this->bias, "bias")
(lora_down, "lora_down", ParamFlags::Optional)
@@ -143,7 +143,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
if (key == "lora_down" || key == "lora_up") {
assert(src.ndims() == 2);
if (dst.shape.dataExtent != src.shape.dataExtent) {
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
this->lora_rank = dst.shape[1];
this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
} else {
@@ -152,7 +152,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else if (key == "wcscales") {
assert(src.ndims() == 1);
assert(src.shape[0] == out_features_pad);
dst = src.copy(this->qweight.device());
dst = src.copy(this->device);
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
@@ -242,15 +242,15 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
// shape[-1] = out_features;
auto shape = TensorShape(qact.actShape.dataExtent);
shape[-1] = out_features;
out = Tensor::allocate(shape, dtype, qweight.device());
out = Tensor::allocate(shape, dtype, device);
} else {
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device);
if (use_fp4) {
qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, device);
} else {
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device);
}
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
qout.is_unsigned = !use_fp4;
qout.actShape = qact.actShape;
@@ -363,13 +363,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
// shape[-1] = in_features / 2;
QuantizedActivation qact;
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, device);
if (use_fp4) {
qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, device);
} else {
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device);
}
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
qact.is_unsigned = false;
qact.actShape = x.shape.dataExtent;
@@ -420,7 +420,7 @@ GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::Scala
this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
registerParams
(qweight, "qweight")
(qweight, "qweight", ParamFlags::LazyLoad)
(wscales, "wscales")
(this->bias, "bias")
;
@@ -36,6 +36,7 @@ public:
int lora_rank;
float lora_scale;
const Device device;
public:
Tensor qweight;
Tensor wscales;
@@ -86,6 +87,7 @@ public:
std::vector<float> lora_scales; // every 16 ranks share a scale
const Tensor::ScalarType dtype;
const Device device;
protected:
virtual void loadParam(std::string key, Tensor &dst, Tensor src) override;
@@ -44,6 +44,7 @@ private:
class MMapImpl;
class MMapImplMio;
class MMapImplPrivate;
class MMapImplRead;
struct TensorInfo {
TensorShape shape;
@@ -54,4 +55,6 @@ private:
};
std::map<std::string, TensorInfo> tensors;
std::unique_ptr<MMapImpl> mapped;
bool hostRegistered, memoryPinned;
};
\ No newline at end of file
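The new hostRegistered / memoryPinned members suggest the mmapped safetensors region can be page-locked so weight uploads run asynchronously with compute. As a general CUDA idiom, shown here as an illustration rather than the repository's code, pinning an existing mapping looks like this:

#include <cstddef>
#include <cuda_runtime.h>

// General idiom, illustrative only: page-lock an already-mapped host region so
// later cudaMemcpyAsync host-to-device copies from it can be truly asynchronous.
inline void pinMappedRegion(void *base, std::size_t bytes) {
    cudaHostRegister(base, bytes, cudaHostRegisterDefault);
}

inline void unpinMappedRegion(void *base) {
    cudaHostUnregister(base); // unpin before the mapping itself is torn down
}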
@@ -85,14 +85,15 @@ public:
if (size == 0) {
this->ptr = nullptr;
}
checkCUDA(cudaMallocAsync(&this->ptr, size, 0)); // use default stream to sync with all other streams
// TODO: buffer used in multiple streams?
checkCUDA(cudaMallocAsync(&this->ptr, size, getCurrentCUDAStream()));
}
virtual ~BufferCUDA() {
if (this->size == 0) {
assert(!this->ptr);
return;
}
checkCUDA(cudaFreeAsync(this->ptr, 0));
checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
}
virtual bool isAsyncBuffer() override {
return true;
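Switching cudaMallocAsync / cudaFreeAsync from the default stream to getCurrentCUDAStream() makes allocation and free stream-ordered with the work that uses the buffer, but a buffer later consumed on a different stream then needs explicit ordering, which is the concern behind the removed TODO. The standard way to express that ordering, given here as a generic CUDA pattern and not as this repository's code:

#include <cuda_runtime.h>

// Generic pattern, not from this repository: make streamB wait for everything
// already enqueued on streamA (the allocation and any producer writes of a
// buffer) before streamB touches that buffer. Error handling elided.
inline void orderStreams(cudaStream_t streamA, cudaStream_t streamB) {
    cudaEvent_t ev;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    cudaEventRecord(ev, streamA);         // marks the last work enqueued on streamA
    cudaStreamWaitEvent(streamB, ev, 0);  // streamB blocks until that work completes
    cudaEventDestroy(ev);                 // safe: destruction is deferred until the event fires
}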
@@ -361,7 +362,7 @@ public:
Tensor &zero_() {
assert(this->is_contiguous());
checkCUDA(cudaMemset(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size()));
checkCUDA(cudaMemsetAsync(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
return *this;
}
Tensor &copy_(Tensor other) {
@@ -307,7 +307,7 @@ Tensor gemv_awq(
return;
}
if constexpr (M > 0) {
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads>>>(
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
checkCUDA(cudaGetLastError());
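The launch now passes the current stream as the fourth launch-configuration argument (the third is the dynamic shared-memory size), which keeps the GEMV ordered with the stream-ordered allocations and memsets elsewhere in this commit. A minimal, self-contained illustration of that syntax with a stand-in kernel, not a kernel from this repository:

#include <cuda_runtime.h>

// fill_kernel is a placeholder used only to show the <<<grid, block, smem, stream>>>
// launch form; it is not part of this codebase.
__global__ void fill_kernel(float *out, float value, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = value;
}

inline void fillOnStream(float *out, float value, int n, cudaStream_t stream) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    // third argument: dynamic shared memory in bytes; fourth: the target stream
    fill_kernel<<<blocks, threads, 0, stream>>>(out, value, n);
}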