Unverified Commit 57e50f8d authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
@@ -8,11 +8,9 @@ import numpy as np
 import torch
 from diffusers import FluxPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.pipelines.flux.pipeline_flux import calculate_shift, EXAMPLE_DOC_STRING, retrieve_timesteps
+from diffusers.pipelines.flux.pipeline_flux import EXAMPLE_DOC_STRING, calculate_shift, retrieve_timesteps
 from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
-from diffusers.utils import (
-    replace_example_docstring,
-)
+from diffusers.utils import replace_example_docstring
 from facexlib.parsing import init_parsing_model
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
 from huggingface_hub import hf_hub_download, snapshot_download
......
-[build-system]
-requires = [
-    "setuptools",
-    "torch>=2.5",
-    "wheel",
-    "ninja",
-]
-build-backend = "setuptools.build_meta"
-[tool.setuptools.packages.find]
-include = ["nunchaku"]
+[tool.isort]
+profile = "black"
+known_first_party = ["nunchaku"]
+line_length = 120
+[tool.black]
+line-length = 120
+target-version = ['py311']
 [tool.ruff]
-line-length = 140
+line-length = 120
+[tool.ruff.lint]
+select = ["E", "W", "F"]
+ignore = ["F401"]
 [project]
 dynamic = ["version"]
@@ -29,3 +22,15 @@ dependencies = [
     "huggingface_hub",
 ]
 requires-python = ">=3.10"
+[build-system]
+requires = [
+    "setuptools",
+    "torch>=2.5",
+    "wheel",
+    "ninja",
+]
+build-backend = "setuptools.build_meta"
+[tool.setuptools.packages.find]
+include = ["nunchaku"]
@@ -6,7 +6,7 @@ import sys
 import setuptools
 import torch
 from packaging import version as packaging_version
-from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
+from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
 class CustomBuildExtension(BuildExtension):
......
This diff is collapsed.
@@ -7,7 +7,7 @@
 #include "layernorm.h"
 #include <pybind11/functional.h>
 namespace pybind11 {
 class function;
 }
 enum class AttentionImpl {
@@ -49,6 +49,7 @@ public:
         Tensor scale_mlp;
         Tensor gate_mlp;
     };
+
 public:
     AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device);
     Output forward(Tensor x, Tensor emb);
@@ -87,7 +88,13 @@ public:
     static constexpr bool USE_4BIT = true;
     using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
-    FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
+    FluxSingleTransformerBlock(int dim,
+                               int num_attention_heads,
+                               int attention_head_dim,
+                               int mlp_ratio,
+                               bool use_fp4,
+                               Tensor::ScalarType dtype,
+                               Device device);
     Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);
 public:
@@ -113,8 +120,19 @@ public:
     static constexpr bool USE_4BIT = true;
     using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
-    JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
-    std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);
+    JointTransformerBlock(int dim,
+                          int num_attention_heads,
+                          int attention_head_dim,
+                          bool context_pre_only,
+                          bool use_fp4,
+                          Tensor::ScalarType dtype,
+                          Device device);
+    std::tuple<Tensor, Tensor> forward(Tensor hidden_states,
+                                       Tensor encoder_hidden_states,
+                                       Tensor temb,
+                                       Tensor rotary_emb,
+                                       Tensor rotary_emb_context,
+                                       float sparsityRatio);
 public:
     const int dim;
@@ -143,8 +161,7 @@ private:
 class FluxModel : public Module {
 public:
     FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
-    Tensor forward(
-        Tensor hidden_states,
+    Tensor forward(Tensor hidden_states,
                    Tensor encoder_hidden_states,
                    Tensor temb,
                    Tensor rotary_emb_img,
@@ -153,8 +170,7 @@ public:
                    Tensor controlnet_block_samples,
                    Tensor controlnet_single_block_samples,
                    bool skip_first_layer = false);
-    std::tuple<Tensor, Tensor> forward_layer(
-        size_t layer,
+    std::tuple<Tensor, Tensor> forward_layer(size_t layer,
                                              Tensor hidden_states,
                                              Tensor encoder_hidden_states,
                                              Tensor temb,
@@ -164,14 +180,16 @@ public:
                                              Tensor controlnet_single_block_samples);
     void setAttentionImpl(AttentionImpl impl);
-    void set_residual_callback(std::function<Tensor(const Tensor&)> cb);
+    void set_residual_callback(std::function<Tensor(const Tensor &)> cb);
 public:
     const Tensor::ScalarType dtype;
     std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
     std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
-    std::function<Tensor(const Tensor&)> residual_callback;
+    std::function<Tensor(const Tensor &)> residual_callback;
 private:
     bool offload;
 };
@@ -9,16 +9,12 @@
 using namespace nunchaku;
-GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features), out_features(out_features)
-{
+GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
+    : in_features(in_features), out_features(out_features) {
     this->weight = Tensor::allocate({out_features, in_features}, dtype, device);
     this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};
-    registerParams
-        (weight, "weight", ParamFlags::LazyLoad)
-        (bias, "bias")
-    ;
+    registerParams(weight, "weight", ParamFlags::LazyLoad)(bias, "bias");
 }
 Tensor GEMM_F16::forward(Tensor x) {
@@ -26,9 +22,9 @@ Tensor GEMM_F16::forward(Tensor x) {
     return out;
 }
-GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device)
-{
+GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
+    : in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f),
+      device(device) {
     this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
     this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
     this->wzeros = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
@@ -38,14 +34,8 @@ GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::Sca
     this->lora_down = Tensor::allocate({lora_rank, in_features}, dtype, device, true);
     this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);
-    registerParams
-        (qweight, "qweight", ParamFlags::LazyLoad)
-        (wscales, "wscales")
-        (wzeros, "wzeros")
-        (bias, "bias")
-        (lora_down, "lora_down", ParamFlags::Optional)
-        (lora_up, "lora_up", ParamFlags::Optional)
-    ;
+    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(wzeros, "wzeros")(bias, "bias")(
+        lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional);
 }
 void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
@@ -95,15 +85,12 @@ Tensor GEMV_AWQ::forward(Tensor x) {
     return out;
 }
 #define NO_LORA_FUSION 0
-GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features), out_features(out_features),
-    in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
-    use_fp4(use_fp4),
-    lora_rank(0), dtype(dtype), device(device)
-{
+GEMM_W4A4::GEMM_W4A4(
+    int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device)
+    : in_features(in_features), out_features(out_features), in_features_pad(ceilDiv(in_features, 128) * 128),
+      out_features_pad(ceilDiv(out_features, 128) * 128), use_fp4(use_fp4), lora_rank(0), dtype(dtype), device(device) {
     this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
     if (use_fp4) {
         this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
@@ -125,16 +112,9 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4,
     this->wcscales = Tensor::allocate({0}, dtype, device, true);
-    registerParams
-        (qweight, "qweight", ParamFlags::LazyLoad)
-        (wscales, "wscales")
-        (this->bias, "bias")
-        (lora_down, "lora_down", ParamFlags::Optional)
-        (lora_up, "lora_up", ParamFlags::Optional)
-        (smooth, "smooth")
-        (wtscale, "wtscale", ParamFlags::Optional)
-        (wcscales, "wcscales", ParamFlags::Optional)
-    ;
+    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias")(
+        lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional)(smooth, "smooth")(
+        wtscale, "wtscale", ParamFlags::Optional)(wcscales, "wcscales", ParamFlags::Optional);
 #if NO_LORA_FUSION
     checkCUBLAS(cublasCreate(&handle));
@@ -181,11 +161,21 @@ Tensor GEMM_W4A4::forward_silu(Tensor x) {
     return std::get<Tensor>(this->forward(x, FuseOptions::SILU, nullptr));
 }
-std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
+std::variant<Tensor, GEMM_W4A4::QuantizedActivation>
+GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
     return forward_quant(quantize(x, false), fuse, nextGEMM);
 }
-void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
+void GEMM_W4A4::forward(Tensor x,
+                        Tensor out,
+                        Tensor pool,
+                        Tensor norm_q,
+                        Tensor norm_k,
+                        Tensor rotary_emb,
+                        Tensor out_q,
+                        Tensor out_k,
+                        Tensor out_v,
+                        int numTokens) {
     QuantizedActivation qact = quantize(x, false);
 #if !NO_LORA_FUSION
@@ -198,17 +188,59 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
     debug("gemm.nolora.out", out);
 #endif
-    kernels::gemm_w4a4(
-        qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
-        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{},
-        out_q, out_k, out_v, numTokens
-    );
+    kernels::gemm_w4a4(qact.act,
+                       qweight,
+                       out,
+                       {},
+                       qact.ascales,
+                       wscales,
+                       {},
+                       pool,
+                       qact.lora_act,
+                       this->lora_up,
+                       {},
+                       {},
+                       norm_q,
+                       norm_k,
+                       rotary_emb,
+                       this->bias,
+                       {},
+                       {},
+                       {},
+                       qact.is_unsigned,
+                       this->lora_scales,
+                       false,
+                       use_fp4,
+                       *this->wtscale.data_ptr<float>(),
+                       wcscales.numel() > 0 ? wcscales : Tensor{},
+                       out_q,
+                       out_k,
+                       out_v,
+                       numTokens);
     debug("gemm.out", out);
 #else
     const int M = (int)qact.act.numel() / qact.act.shape[-1];
-    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales);
+    kernels::gemm_w4a4(qact.act,
+                       qweight,
+                       out,
+                       {},
+                       qact.ascales,
+                       wscales,
+                       {},
+                       pool,
+                       {},
+                       {},
+                       {},
+                       {},
+                       norm_q,
+                       norm_k,
+                       rotary_emb,
+                       this->bias,
+                       {},
+                       qact.is_unsigned,
+                       this->lora_scales);
     nvtxRangePushA("LoraUp");
@@ -216,10 +248,12 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
     static const half zero = 0.0;
     // lora_up: [M, R] * [OC, R] => [M, OC]
     // cublas view: [OC, R] * [M, R]^T
-    checkCUBLAS(cublasHgemm(
-        handle,
-        CUBLAS_OP_T, CUBLAS_OP_N,
-        this->out_features, M, this->lora_rank,
+    checkCUBLAS(cublasHgemm(handle,
+                            CUBLAS_OP_T,
+                            CUBLAS_OP_N,
+                            this->out_features,
+                            M,
+                            this->lora_rank,
                             &one,
                             this->lora_up.data_ptr<half>(),
                             this->lora_rank,
@@ -233,7 +267,8 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
 #endif
 }
-std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
+std::variant<Tensor, GEMM_W4A4::QuantizedActivation>
+GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
     Tensor out;
     QuantizedActivation qout;
@@ -280,11 +315,35 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
     }
 #endif
-    kernels::gemm_w4a4(
-        qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
-        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{},
-        {}, {}, {}, 0
-    );
+    kernels::gemm_w4a4(qact.act,
+                       qweight,
+                       out,
+                       qout.act,
+                       qact.ascales,
+                       wscales,
+                       qout.ascales,
+                       {},
+                       qact.lora_act,
+                       this->lora_up,
+                       next_lora,
+                       qout.lora_act,
+                       {},
+                       {},
+                       {},
+                       this->bias,
+                       next_smooth,
+                       {},
+                       {},
+                       qact.is_unsigned,
+                       this->lora_scales,
+                       fuse == FuseOptions::SILU,
+                       use_fp4,
+                       *this->wtscale.data_ptr<float>(),
+                       wcscales.numel() > 0 ? wcscales : Tensor{},
+                       {},
+                       {},
+                       {},
+                       0);
     if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
         debug("gemm.out", out);
@@ -294,7 +353,6 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
         debug("gemm.lora_act_out", qout.lora_act);
     }
-
 #else
     if (!out.valid()) {
         auto shape = TensorShape(qact.act.shape.dataExtent);
@@ -302,7 +360,25 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
         out = Tensor::allocate(shape, Tensor::FP16, qweight.device());
     }
-    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {}, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales);
+    kernels::gemm_w4a4(qact.act,
+                       qweight,
+                       out,
+                       qout.act,
+                       qact.ascales,
+                       wscales,
+                       qout.ascales,
+                       {},
+                       {},
+                       {},
+                       {},
+                       {},
+                       {},
+                       {},
+                       {},
+                       this->bias,
+                       next_smooth,
+                       qact.is_unsigned,
+                       this->lora_scales);
     nvtxRangePushA("LoraUp");
@@ -312,10 +388,12 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
     // lora_up: [M, R] * [OC, R]^T => [M, OC]
     // cublas view: [R, OC]^T * [R, M] => [OC, M]
     // lora_up layout wrong?
-    checkCUBLAS(cublasHgemm(
-        handle,
-        CUBLAS_OP_T, CUBLAS_OP_N,
-        this->out_features, M, this->lora_rank,
+    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_T,
                            CUBLAS_OP_N,
+                            this->out_features,
+                            M,
+                            this->lora_rank,
                            &one,
                            this->lora_up.data_ptr<half>(),
                            this->lora_rank,
@@ -332,10 +410,12 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
     // IC is for next lora (OC of this layer)
     // lora_down: [M, IC] * [IC, R] => [M, R]
     // cublas view: [R, IC] * [IC, M] => [R, M]
-    checkCUBLAS(cublasHgemm(
-        handle,
-        CUBLAS_OP_N, CUBLAS_OP_N,
-        this->lora_rank, M, this->out_features,
+    checkCUBLAS(cublasHgemm(handle,
+                            CUBLAS_OP_N,
+                            CUBLAS_OP_N,
+                            this->lora_rank,
+                            M,
+                            this->out_features,
                            &one,
                            next_lora.data_ptr<half>(),
                            this->lora_rank,
@@ -383,7 +463,8 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     debug("quantize.x", x);
     debug("quantize.smooth", this->smooth);
-    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);
+    kernels::quantize_w4a4_act_fuse_lora(
+        x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);
     debug("quantize.qact", qact.act);
     debug("quantize.ascales", qact.ascales);
@@ -396,10 +477,12 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     // lora_down: [M, IC] * [IC, R] => [M, R]
     // cublas view: [R, IC] * [IC, M]
-    checkCUBLAS(cublasHgemm(
-        handle,
-        CUBLAS_OP_N, CUBLAS_OP_N,
-        this->lora_rank, M, this->in_features,
+    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_N,
                            CUBLAS_OP_N,
+                            this->lora_rank,
+                            M,
+                            this->in_features,
                            &one,
                            lora_down.data_ptr<half>(),
                            this->lora_rank,
@@ -418,18 +501,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     return qact;
 }
-GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features), out_features(out_features), dtype(dtype)
-{
+GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device)
+    : in_features(in_features), out_features(out_features), dtype(dtype) {
     this->qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device);
     this->wscales = Tensor::allocate({out_features}, dtype, device);
     this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};
-    registerParams
-        (qweight, "qweight", ParamFlags::LazyLoad)
-        (wscales, "wscales")
-        (this->bias, "bias")
-    ;
+    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias");
 }
 GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
@@ -461,16 +539,11 @@ Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
     return out;
 }
-DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features)
-{
+DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) : in_features(in_features) {
     this->weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device);
     this->bias = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};
-    registerParams
-        (this->weight, "weight")
-        (this->bias, "bias")
-    ;
+    registerParams(this->weight, "weight")(this->bias, "bias");
 }
 Tensor DWCONV::forward(Tensor x) {
......
@@ -37,6 +37,7 @@ public:
     float lora_scale;
     const Device device;
+
 public:
     Tensor qweight;
     Tensor wscales;
@@ -69,12 +70,18 @@ public:
     Tensor forward(Tensor x);
     Tensor forward_silu(Tensor x);
     std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
-    void forward(
-        Tensor x, Tensor out,
-        Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {},
-        Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {}, int numTokens = 0
-    );
-    std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
+    void forward(Tensor x,
+                 Tensor out,
+                 Tensor pool = {},
+                 Tensor norm_q = {},
+                 Tensor norm_k = {},
+                 Tensor rotary_emb = {},
+                 Tensor out_q = {},
+                 Tensor out_k = {},
+                 Tensor out_v = {},
+                 int numTokens = 0);
+    std::variant<Tensor, QuantizedActivation>
+    forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
     Tensor forward_quant(QuantizedActivation qact);
 public:
@@ -118,13 +125,16 @@ public:
         Tensor act;
         Tensor ascales;
     };
+
 public:
     GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
 public:
     QuantizedActivation quantize(Tensor x, bool fuse_glu);
     Tensor forward_quant(QuantizedActivation qact);
-    Tensor forward(Tensor x) { return forward_quant(quantize(x, false)); }
+    Tensor forward(Tensor x) {
+        return forward_quant(quantize(x, false));
+    }
 public:
     const int in_features;
......
@@ -108,7 +108,8 @@ public:
             dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device);
             if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) {
-                throw std::runtime_error(spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
+                throw std::runtime_error(
+                    spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
             }
             m->loadParam(key, dst, src);
         }
@@ -127,14 +128,10 @@ public:
         });
     }
     void setLazyLoad(bool val) {
-        traverse([val](Module *m) {
-            m->enabledLazyLoad = val;
-        });
+        traverse([val](Module *m) { m->enabledLazyLoad = val; });
     }
     void setAutoCastFP16(bool val) {
-        traverse([val](Module *m) {
-            m->enabledAutoCastFP16 = val;
-        });
+        traverse([val](Module *m) { m->enabledAutoCastFP16 = val; });
     }
 protected:
@@ -143,7 +140,8 @@ protected:
             Tensor::FP16,
             Tensor::BF16,
         };
-        if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) && whitelist.contains(src.scalar_type())) {
+        if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) &&
+            whitelist.contains(src.scalar_type())) {
             copyWithCast(dst, src);
         } else {
             dst.copy_(src);
@@ -227,8 +225,7 @@ struct LayerOffloadHelper {
     std::unique_ptr<CUDAEventWrapper> eventLoadDone;
     LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
-        : offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload)
-    {
+        : offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload) {
         if (offload) {
             streamCompute = std::make_unique<CUDAStreamWrapper>();
             streamLoad = std::make_unique<CUDAStreamWrapper>();
@@ -305,11 +302,11 @@ private:
         }
     }
 #ifdef _WIN32
         return true;
 #else
         return false;
 #endif
     }
     void workaroundFlush() {
         if (!needWorkaround) {
......
@@ -10,18 +10,11 @@
 using spdlog::fmt_lib::format;
 using namespace nunchaku;
-SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
-    dim(dim),
-    dim_pad(ceilDiv(dim, 128) * 128),
-    qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
-    out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
-    pag_to_v(std::nullopt)
-{
-    registerChildren
-        (qkv_proj, "qkv_proj")
-        (out_proj, "out_proj")
-    ;
+SanaLinearAttention::SanaLinearAttention(
+    int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device)
+    : dim(dim), dim_pad(ceilDiv(dim, 128) * 128), qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
+      out_proj(dim_pad, dim, bias, use_fp4, dtype, device), pag_to_v(std::nullopt) {
+    registerChildren(qkv_proj, "qkv_proj")(out_proj, "out_proj");
     if (pag) {
         pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
@@ -57,21 +50,35 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
     Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
     Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());
-    kernels::gemm_w4a4(
-        qact.act,
-        qkv_proj.qweight,
-        {},
-        {},
-        qact.ascales,
-        qkv_proj.wscales,
-        {}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {}, {},
-        vk, q,
-        qact.is_unsigned, qkv_proj.lora_scales, false,
-        qkv_proj.use_fp4,
-        *qkv_proj.wtscale.data_ptr<float>(),
-        qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
-        {}, {}, {}, 0
-    );
+    kernels::gemm_w4a4(qact.act,
                       qkv_proj.qweight,
                       {},
                       {},
                       qact.ascales,
                       qkv_proj.wscales,
+                       {},
+                       {},
+                       qact.lora_act,
+                       qkv_proj.lora_up,
+                       {},
+                       {},
+                       {},
+                       {},
+                       {},
+                       qkv_proj.bias,
+                       {},
+                       vk,
+                       q,
+                       qact.is_unsigned,
+                       qkv_proj.lora_scales,
+                       false,
                       qkv_proj.use_fp4,
                       *qkv_proj.wtscale.data_ptr<float>(),
                       qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
+                       {},
+                       {},
+                       {},
+                       0);
     debug("vk", vk);
     debug("q", q);
@@ -88,7 +95,6 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
         q = q_unpad;
     }
-
     // kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
     // return out_proj.forward(q);
@@ -129,17 +135,13 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
     return out;
 }
-MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
-    num_heads(num_heads), head_dim(head_dim),
+MultiHeadCrossAttention::MultiHeadCrossAttention(
+    int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device)
+    : num_heads(num_heads), head_dim(head_dim),
      q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
     kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
-    out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device)
-{
-    registerChildren
-        (q_linear, "q_linear")
-        (kv_linear, "kv_linear")
-        (out_proj, "out_proj")
-    ;
+      out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device) {
+    registerChildren(q_linear, "q_linear")(kv_linear, "kv_linear")(out_proj, "out_proj");
 }
 Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) {
@@ -161,16 +163,22 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
     Tensor k = kv.slice(1, 0, num_heads);
     Tensor v = kv.slice(1, num_heads, num_heads * 2);
-    Tensor attn_output = mha_varlen_fwd(
-        q, k, v,
-        cu_seqlens_img, cu_seqlens_txt,
-        num_tokens_img, num_tokens_txt,
-        0.0f,
-        pow(q.shape[-1], (-0.5)),
-        false, false,
-        -1, -1,
-        false
-    ).front().view({batch_size, num_tokens_img, num_heads * head_dim});
+    Tensor attn_output = mha_varlen_fwd(q,
+                                        k,
+                                        v,
+                                        cu_seqlens_img,
+                                        cu_seqlens_txt,
+                                        num_tokens_img,
+                                        num_tokens_txt,
                                        0.0f,
                                        pow(q.shape[-1], (-0.5)),
+                                        false,
+                                        false,
+                                        -1,
+                                        -1,
+                                        false)
+                             .front()
+                             .view({batch_size, num_tokens_img, num_heads * head_dim});
     // Tensor attn_output = mha_fwd(q, k, v,
     //                              0.0f,
@@ -181,17 +189,13 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
     return out_proj.forward(attn_output);
 }
-SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
-    in_features(in_features), hidden_features(hidden_features),
+SanaGLUMBConv::SanaGLUMBConv(
+    int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device)
+    : in_features(in_features), hidden_features(hidden_features),
      inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
     depth_conv(hidden_features * 2, true, dtype, device),
-    point_conv(hidden_features, in_features, false, use_fp4, dtype, device)
-{
-    registerChildren
-        (inverted_conv, "inverted_conv")
-        (depth_conv, "depth_conv")
-        (point_conv, "point_conv")
-    ;
+      point_conv(hidden_features, in_features, false, use_fp4, dtype, device) {
+    registerChildren(inverted_conv, "inverted_conv")(depth_conv, "depth_conv")(point_conv, "point_conv");
 }
 Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
@@ -208,28 +212,34 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
     return point_conv.forward_quant(qact);
 }
-SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
-    hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
+SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size,
+                                                       int intermediate_size,
+                                                       int num_cross_attention_heads,
+                                                       bool pag,
+                                                       bool use_fp4,
+                                                       Tensor::ScalarType dtype,
+                                                       Device device)
+    : hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
      attn(hidden_size, false, pag, use_fp4, dtype, device),
     cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
-    ff(hidden_size, intermediate_size, use_fp4, dtype, device),
-    norm1(hidden_size, 1e-6, false, dtype, device),
-    norm2(hidden_size, 1e-6, false, dtype, device)
-{
+      ff(hidden_size, intermediate_size, use_fp4, dtype, device), norm1(hidden_size, 1e-6, false, dtype, device),
+      norm2(hidden_size, 1e-6, false, dtype, device) {
     this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);
-    registerChildren
-        (attn, "attn")
-        (cross_attn, "cross_attn")
-        (ff, "ff")
-    ;
+    registerChildren(attn, "attn")(cross_attn, "cross_attn")(ff, "ff");
-    registerParams
-        (this->scale_shift_table, "scale_shift_table")
-    ;
+    registerParams(this->scale_shift_table, "scale_shift_table");
 }
-Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
+Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
+                                           Tensor encoder_hidden_states,
+                                           Tensor timestep,
+                                           Tensor cu_seqlens_img,
+                                           Tensor cu_seqlens_txt,
+                                           int H,
+                                           int W,
+                                           bool pag,
+                                           bool cfg) {
     nvtxRangePushA("SanaLinearTransformerBlock");
@@ -311,9 +321,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
     return hidden_states;
 }
-SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) :
-    config(config)
-{
+SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : config(config) {
     const int inner_dim = config.num_attention_heads * config.attention_head_dim;
     for (int i = 0; i < config.num_layers; i++) {
         transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>(
@@ -322,20 +330,34 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
             config.num_cross_attention_heads,
             std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
             config.use_fp4,
-            dtype, device
-        ));
+            dtype,
+            device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
     }
 }
-Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
+Tensor SanaModel::forward(Tensor hidden_states,
+                          Tensor encoder_hidden_states,
+                          Tensor timestep,
+                          Tensor cu_seqlens_img,
+                          Tensor cu_seqlens_txt,
+                          int H,
+                          int W,
+                          bool pag,
+                          bool cfg,
+                          bool skip_first_layer) {
     for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
         auto &&block = transformer_blocks[i];
-        hidden_states = block->forward(
-            hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
-            pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
-            cfg
-        );
+        hidden_states = block->forward(hidden_states,
+                                       encoder_hidden_states,
+                                       timestep,
+                                       cu_seqlens_img,
+                                       cu_seqlens_txt,
+                                       H,
+                                       W,
+                                       pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) !=
+                                           config.pag_layers.end(),
+                                       cfg);
     }
     return hidden_states;
 }
@@ -57,9 +57,23 @@ private:
 class SanaLinearTransformerBlock : public Module {
 public:
-    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
-    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
+    SanaLinearTransformerBlock(int hidden_size,
+                               int intermediate_size,
+                               int num_cross_attention_heads,
+                               bool pag,
+                               bool use_fp4,
+                               Tensor::ScalarType dtype,
+                               Device device);
+    Tensor forward(Tensor hidden_states,
+                   Tensor encoder_hidden_states,
+                   Tensor timestep,
+                   Tensor cu_seqlens_img,
+                   Tensor cu_seqlens_txt,
+                   int H,
+                   int W,
+                   bool pag,
+                   bool cfg);
 public:
     const int hidden_size;
@@ -89,7 +103,16 @@ struct SanaConfig {
 class SanaModel : public Module {
 public:
     SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
-    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);
+    Tensor forward(Tensor hidden_states,
+                   Tensor encoder_hidden_states,
+                   Tensor timestep,
+                   Tensor cu_seqlens_img,
+                   Tensor cu_seqlens_txt,
+                   int H,
+                   int W,
+                   bool pag,
+                   bool cfg,
+                   bool skip_first_layer);
 public:
     const SanaConfig config;
......