Commit 873a35be authored by muyangli

v0.1.4 ready to release


Co-authored-by: Zhekai Zhang <sxtyzhangzk@gmail.com>
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
Co-authored-by: Yujun Lin <16437040+synxlin@users.noreply.github.com>
parent d9cd6858
@@ -607,14 +607,22 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     return { hidden_states, encoder_hidden_states };
 }
 
-FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
+FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : offload(offload) {
     for (int i = 0; i < 19; i++) {
         transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
+        if (offload && i > 0) {  // don't offload first block
+            transformer_blocks.back()->setLazyLoad(true);
+            transformer_blocks.back()->releaseLazyParams();
+        }
     }
 
     for (int i = 0; i < 38; i++) {
         single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
         registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
+        if (offload) {
+            single_transformer_blocks.back()->setLazyLoad(true);
+            single_transformer_blocks.back()->releaseLazyParams();
+        }
     }
 }
@@ -626,22 +634,51 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
     const int txt_tokens = encoder_hidden_states.shape[1];
     const int img_tokens = hidden_states.shape[1];
 
-    for (auto &&block : transformer_blocks) {
-        std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
-    }
-
-    // txt first, same as diffusers
-    Tensor concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
-    for (int i = 0; i < batch_size; i++) {
-        concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
-        concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
-    }
-    hidden_states = concat;
-    encoder_hidden_states = {};
-
-    for (auto &&block : single_transformer_blocks) {
-        hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
-    }
+    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
+
+    Tensor concat;
+
+    auto compute = [&](int layer) {
+        if (size_t(layer) < transformer_blocks.size()) {
+            auto &block = transformer_blocks.at(layer);
+            std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
+        } else {
+            if (size_t(layer) == transformer_blocks.size()) {
+                // txt first, same as diffusers
+                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
+                for (int i = 0; i < batch_size; i++) {
+                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
+                    concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
+                }
+                hidden_states = concat;
+                encoder_hidden_states = {};
+            }
+
+            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
+            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
+        }
+    };
+    auto load = [&](int layer) {
+        if (size_t(layer) < transformer_blocks.size()) {
+            auto &block = transformer_blocks.at(layer);
+            block->loadLazyParams();
+        } else {
+            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
+            block->loadLazyParams();
+        }
+    };
+    auto unload = [&](int layer) {
+        if (size_t(layer) < transformer_blocks.size()) {
+            auto &block = transformer_blocks.at(layer);
+            block->releaseLazyParams();
+        } else {
+            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
+            block->releaseLazyParams();
+        }
+    };
+
+    LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
+    helper.run();
 
     return hidden_states;
 }
\ No newline at end of file
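
The LayerOffloadHelper driving the compute/load/unload callbacks above is added elsewhere in this commit and is not shown in this excerpt. As a rough mental model only, here is a minimal sequential stand-in with the same constructor/run() call shape; the class name and the purely sequential schedule are assumptions for illustration, not the project's actual implementation, which presumably prefetches the next layer's weights so transfers overlap with compute.

// Sketch only: sequential driver with the same call shape as the usage above.
#include <functional>
#include <utility>

class SequentialOffloadDriver {
public:
    SequentialOffloadDriver(bool offload, int numLayers,
                            std::function<void(int)> compute,
                            std::function<void(int)> load,
                            std::function<void(int)> unload)
        : offload(offload), numLayers(numLayers),
          compute(std::move(compute)), load(std::move(load)), unload(std::move(unload)) {}

    void run() {
        if (!offload) {
            // All weights stay resident: just run the layers back to back.
            for (int i = 0; i < numLayers; i++) compute(i);
            return;
        }
        for (int i = 0; i < numLayers; i++) {
            load(i);     // materialize this layer's lazily loaded params on the GPU
            compute(i);  // run the layer
            unload(i);   // release the params again to bound peak memory
        }
    }

private:
    bool offload;
    int numLayers;
    std::function<void(int)> compute, load, unload;
};

A pipelined variant would issue load(i + 1) on a transfer stream while compute(i) runs, then synchronize before computing layer i + 1.
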
@@ -128,10 +128,13 @@ private:
 class FluxModel : public Module {
 public:
-    FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
+    FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);
 
 public:
     std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
     std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
+
+private:
+    bool offload;
 };
\ No newline at end of file
@@ -16,7 +16,7 @@ GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::Sca
     this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
 
     registerParams
-        (weight, "weight")
+        (weight, "weight", ParamFlags::LazyLoad)
         (bias, "bias")
     ;
 }
@@ -27,7 +27,7 @@ Tensor GEMM_F16::forward(Tensor x) {
 }
 
 GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f)
+    in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device)
 {
     this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
     this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
@@ -39,7 +39,7 @@ GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::Sca
     this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
 
     registerParams
-        (qweight, "qweight")
+        (qweight, "qweight", ParamFlags::LazyLoad)
         (wscales, "wscales")
         (wzeros, "wzeros")
         (bias, "bias")
@@ -52,7 +52,7 @@ void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
     if (key == "lora_down" || key == "lora_up") {
         assert(src.ndims() == 2);
         if (dst.shape.dataExtent != src.shape.dataExtent) {
-            dst = src.copy(this->qweight.device());
+            dst = src.copy(this->device);
             if (key == "lora_down") {
                 const int new_rank = dst.shape[0];
                 this->lora_rank = new_rank;
@@ -100,7 +100,7 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
     in_features(in_features), out_features(out_features),
     in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
     use_fp4(use_fp4),
-    lora_rank(0), dtype(dtype)
+    lora_rank(0), dtype(dtype), device(device)
 {
     this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
     if (use_fp4) {
@@ -124,7 +124,7 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
     this->wcscales = Tensor::allocate({0}, dtype, device, true);
 
     registerParams
-        (qweight, "qweight")
+        (qweight, "qweight", ParamFlags::LazyLoad)
        (wscales, "wscales")
        (this->bias, "bias")
        (lora_down, "lora_down", ParamFlags::Optional)
@@ -143,7 +143,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
     if (key == "lora_down" || key == "lora_up") {
         assert(src.ndims() == 2);
         if (dst.shape.dataExtent != src.shape.dataExtent) {
-            dst = src.copy(this->qweight.device());
+            dst = src.copy(this->device);
             this->lora_rank = dst.shape[1];
             this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
         } else {
@@ -152,7 +152,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
     } else if (key == "wcscales") {
         assert(src.ndims() == 1);
         assert(src.shape[0] == out_features_pad);
-        dst = src.copy(this->qweight.device());
+        dst = src.copy(this->device);
     } else if (key == "wtscale") {
         assert(src.numel() == 1);
         if (src.dtype() == Tensor::BF16) {
@@ -242,15 +242,15 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
         // shape[-1] = out_features;
         auto shape = TensorShape(qact.actShape.dataExtent);
         shape[-1] = out_features;
-        out = Tensor::allocate(shape, dtype, qweight.device());
+        out = Tensor::allocate(shape, dtype, device);
     } else {
-        qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
+        qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device);
         if (use_fp4) {
-            qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+            qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, device);
         } else {
-            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device);
         }
-        qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
+        qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
         qout.is_unsigned = !use_fp4;
 
         qout.actShape = qact.actShape;
@@ -363,13 +363,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     // shape[-1] = in_features / 2;
 
     QuantizedActivation qact;
-    qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
+    qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, device);
     if (use_fp4) {
-        qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+        qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, device);
     } else {
-        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device);
     }
-    qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
+    qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
     qact.is_unsigned = false;
     qact.actShape = x.shape.dataExtent;
@@ -420,7 +420,7 @@ GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::Scala
     this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
 
     registerParams
-        (qweight, "qweight")
+        (qweight, "qweight", ParamFlags::LazyLoad)
         (wscales, "wscales")
         (this->bias, "bias")
     ;
...
@@ -36,6 +36,7 @@ public:
     int lora_rank;
     float lora_scale;
+    const Device device;
 
 public:
     Tensor qweight;
     Tensor wscales;
@@ -86,6 +87,7 @@ public:
     std::vector<float> lora_scales;  // every 16 ranks share a scale
 
     const Tensor::ScalarType dtype;
+    const Device device;
 
 protected:
     virtual void loadParam(std::string key, Tensor &dst, Tensor src) override;
@@ -44,6 +44,7 @@ private:
     class MMapImpl;
     class MMapImplMio;
     class MMapImplPrivate;
+    class MMapImplRead;
 
     struct TensorInfo {
         TensorShape shape;
@@ -54,4 +55,6 @@ private:
     };
     std::map<std::string, TensorInfo> tensors;
     std::unique_ptr<MMapImpl> mapped;
+
+    bool hostRegistered, memoryPinned;
 };
\ No newline at end of file
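
The new hostRegistered and memoryPinned members suggest the memory-mapped safetensors file can be page-locked so weight uploads run asynchronously; the logic itself lives in a part of the diff not shown here. If it follows the standard approach, it would resemble the hedged sketch below (the helper name and fallback behavior are assumptions).

// Hedged sketch: pinning an existing mapping with cudaHostRegister. Read-only
// registration (CUDA 11.1+) avoids marking the file mapping writable.
#include <cstddef>
#include <cstdio>
#include <cuda_runtime.h>

bool tryPinMapping(void *base, size_t length) {
    cudaError_t err = cudaHostRegister(base, length, cudaHostRegisterReadOnly);
    if (err != cudaSuccess) {
        // Unpinned memory still works; copies just fall back to synchronous staging.
        std::fprintf(stderr, "cudaHostRegister failed: %s\n", cudaGetErrorString(err));
        return false;
    }
    return true;   // remember to cudaHostUnregister(base) before unmapping
}
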
@@ -85,14 +85,15 @@ public:
         if (size == 0) {
             this->ptr = nullptr;
         }
-        checkCUDA(cudaMallocAsync(&this->ptr, size, 0));  // use default stream to sync with all other streams
+        // TODO: buffer used in multiple streams?
+        checkCUDA(cudaMallocAsync(&this->ptr, size, getCurrentCUDAStream()));
     }
     virtual ~BufferCUDA() {
         if (this->size == 0) {
             assert(!this->ptr);
             return;
         }
-        checkCUDA(cudaFreeAsync(this->ptr, 0));
+        checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
     }
     virtual bool isAsyncBuffer() override {
         return true;
@@ -361,7 +362,7 @@ public:
     Tensor &zero_() {
         assert(this->is_contiguous());
-        checkCUDA(cudaMemset(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size()));
+        checkCUDA(cudaMemsetAsync(data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
         return *this;
     }
 
     Tensor &copy_(Tensor other) {
...
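
The Tensor changes above switch allocation, free, and memset from the default stream to the current stream, making the whole lifetime of a scratch buffer stream-ordered. The sketch below shows that pattern with a plain cudaStream_t standing in for getCurrentCUDAStream(); as the new TODO notes, a buffer consumed on a different stream would additionally need event-based synchronization.

// Stream-ordered scratch buffer: alloc, use, and free are all enqueued on one stream,
// so no device-wide synchronization is needed and the memory pool can recycle promptly.
#include <cstddef>
#include <cuda_runtime.h>

void scratchRoundTrip(cudaStream_t stream, size_t bytes) {
    void *buf = nullptr;
    cudaMallocAsync(&buf, bytes, stream);    // allocation ordered on `stream`
    cudaMemsetAsync(buf, 0, bytes, stream);  // runs after the alloc on the same stream
    // ... launch kernels that read/write `buf` on `stream` here ...
    cudaFreeAsync(buf, stream);              // memory returns to the pool after prior work
}
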
@@ -307,7 +307,7 @@ Tensor gemv_awq(
             return;
         }
         if constexpr (M > 0) {
-            gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads>>>(
+            gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
                 in_feats, kernel, scaling_factors, zeros, out_feats, k, n
             );
             checkCUDA(cudaGetLastError());
...
@@ -1440,7 +1440,7 @@ public:
     static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2;  // 1 for theta, 2 for {sin, cos} pair
 
     __device__ __forceinline__
-    static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon) {
+    static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon, int maxRows) {
         const int laneId = threadIdx.x % WARP_SIZE;
         const int warpId = threadIdx.x / WARP_SIZE;
@@ -1470,7 +1470,7 @@ public:
         CHECK_NAN(fpsum, "fpsum");
 
-        unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, INT_MAX, INT_MAX, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
+        unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, maxRows - warpId * WARP_M, INT_MAX, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
             // load rope
             pack_rope_t rope;
             if (laneId < LANES_PER_HEAD) {
@@ -1605,7 +1605,8 @@ public:
                 args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
                 args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
                 is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
-                args.epsilon
+                args.epsilon,
+                args.actualM - bm * BLOCK_M
             );
         } else {
             EpilogueDefault()(binfo, fpsum, M, N, K, typename EpilogueDefault::Arguments{
...