Unverified Commit 57e50f8d authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
......@@ -8,11 +8,9 @@ import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, EXAMPLE_DOC_STRING, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_flux import EXAMPLE_DOC_STRING, calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import (
replace_example_docstring,
)
from diffusers.utils import replace_example_docstring
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from huggingface_hub import hf_hub_download, snapshot_download
......
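The import reordering in this hunk is what the upgraded isort pass produces: under isort's default order_by_type grouping, ALL_CAPS constants such as EXAMPLE_DOC_STRING and CUDA_HOME sort ahead of classes and lowercase functions inside a from-import, and a parenthesized import holding a single name collapses onto one line once it fits within 120 characters. A minimal sketch of reproducing the pass locally, assuming isort is installed and run from the repository root so it picks up the [tool.isort] table added in the pyproject.toml hunk below:

# isort reads profile = "black" and line_length = 120 from pyproject.toml;
# running it from the repo root reformats imports the same way this commit did
pip install "isort>=5.12"
isort .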
[build-system]
requires = [
"setuptools",
"torch>=2.5",
"wheel",
"ninja",
]
build-backend = "setuptools.build_meta"
[tool.isort]
profile = "black"
known_first_party = ["nunchaku"]
line_length = 120
[tool.setuptools.packages.find]
include = ["nunchaku"]
[tool.black]
line-length = 120
target-version = ['py311']
[tool.ruff]
line-length = 140
[tool.ruff.lint]
select = ["E", "W", "F"]
ignore = ["F401"]
line-length = 120
[project]
dynamic = ["version"]
......@@ -29,3 +22,15 @@ dependencies = [
"huggingface_hub",
]
requires-python = ">=3.10"
[build-system]
requires = [
"setuptools",
"torch>=2.5",
"wheel",
"ninja",
]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
include = ["nunchaku"]
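Taken together, the new [tool.black], [tool.isort], and [tool.ruff] tables define the Python style pass for this commit: black and isort format at 120 columns (black targeting Python 3.11 syntax), while ruff lints at 140 columns with the pycodestyle (E, W) and pyflakes (F) rule sets enabled and F401 (unused import) suppressed. A sketch of the corresponding local commands, assuming the three tools are installed; the repository's actual CI or pre-commit wiring is not part of this diff:

# format code and imports in place, then lint; each tool reads its
# settings from the matching [tool.*] table in pyproject.toml
black .
isort .
ruff check .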
......@@ -26,4 +26,4 @@ bash scripts/build_linux_wheel_torch2.7_cu128.sh "3.13" "2.7" "12.8"
bash scripts/build_linux_wheel_cu128.sh "3.10" "2.8" "12.8"
bash scripts/build_linux_wheel_cu128.sh "3.11" "2.8" "12.8"
bash scripts/build_linux_wheel_cu128.sh "3.12" "2.8" "12.8"
bash scripts/build_linux_wheel_cu128.sh "3.13" "2.8" "12.8"
\ No newline at end of file
bash scripts/build_linux_wheel_cu128.sh "3.13" "2.8" "12.8"
......@@ -39,4 +39,4 @@ docker build -f docker/Dockerfile --no-cache \
-t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
......@@ -27,4 +27,4 @@ docker build -f docker/Dockerfile.torch27 --no-cache \
-t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
......@@ -39,4 +39,4 @@ docker build -f docker/Dockerfile.torch28 --no-cache \
-t lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION} .
docker push lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
\ No newline at end of file
docker rmi lmxyy/nunchaku:${NUNCHAKU_VERSION}-py${PYTHON_VERSION}-torch${TORCH_VERSION}-cuda${CUDA_VERSION}
......@@ -35,4 +35,4 @@ docker run --rm \
export NUNCHAKU_BUILD_WHEELS=1 && \
export MAX_JOBS=${MAX_JOBS} && \
${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
"
\ No newline at end of file
"
......@@ -33,4 +33,4 @@ docker run --rm \
export NUNCHAKU_BUILD_WHEELS=1 && \
export MAX_JOBS=${MAX_JOBS} && \
${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
"
\ No newline at end of file
"
......@@ -33,4 +33,4 @@ docker run --rm \
export NUNCHAKU_BUILD_WHEELS=1 && \
export MAX_JOBS=${MAX_JOBS} && \
${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
"
\ No newline at end of file
"
......@@ -4,4 +4,4 @@ set -ex
docker run --rm \
-v "$(pwd)":/nunchaku \
pytorch/manylinux-builder:cuda12.4 \
bash -c "cd /nunchaku && rm -rf *"
\ No newline at end of file
bash -c "cd /nunchaku && rm -rf *"
......@@ -6,7 +6,7 @@ import sys
import setuptools
import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
class CustomBuildExtension(BuildExtension):
......
This diff is collapsed.
......@@ -7,7 +7,7 @@
#include "layernorm.h"
#include <pybind11/functional.h>
namespace pybind11 {
class function;
class function;
}
enum class AttentionImpl {
......@@ -18,7 +18,7 @@ enum class AttentionImpl {
class AdaLayerNormZeroSingle : public Module {
public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;
using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;
struct Output {
Tensor x;
......@@ -40,7 +40,7 @@ private:
class AdaLayerNormZero : public Module {
public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;
using GEMM = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;
struct Output {
Tensor x;
......@@ -49,6 +49,7 @@ public:
Tensor scale_mlp;
Tensor gate_mlp;
};
public:
AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device);
Output forward(Tensor x, Tensor emb);
......@@ -85,9 +86,15 @@ private:
class FluxSingleTransformerBlock : public Module {
public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
FluxSingleTransformerBlock(int dim,
int num_attention_heads,
int attention_head_dim,
int mlp_ratio,
bool use_fp4,
Tensor::ScalarType dtype,
Device device);
Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);
public:
......@@ -111,10 +118,21 @@ private:
class JointTransformerBlock : public Module {
public:
static constexpr bool USE_4BIT = true;
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);
using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;
JointTransformerBlock(int dim,
int num_attention_heads,
int attention_head_dim,
bool context_pre_only,
bool use_fp4,
Tensor::ScalarType dtype,
Device device);
std::tuple<Tensor, Tensor> forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb,
Tensor rotary_emb_context,
float sparsityRatio);
public:
const int dim;
......@@ -143,35 +161,35 @@ private:
class FluxModel : public Module {
public:
FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
Tensor forward(
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor rotary_emb_single,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples,
bool skip_first_layer = false);
std::tuple<Tensor, Tensor> forward_layer(
size_t layer,
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples);
Tensor forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor rotary_emb_single,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples,
bool skip_first_layer = false);
std::tuple<Tensor, Tensor> forward_layer(size_t layer,
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples);
void setAttentionImpl(AttentionImpl impl);
void set_residual_callback(std::function<Tensor(const Tensor&)> cb);
void set_residual_callback(std::function<Tensor(const Tensor &)> cb);
public:
const Tensor::ScalarType dtype;
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
std::function<Tensor(const Tensor&)> residual_callback;
std::function<Tensor(const Tensor &)> residual_callback;
private:
bool offload;
};
\ No newline at end of file
};
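The C++ hunks in this commit show the same kind of mechanical rewrap: long constructor and forward() signatures are broken into one parameter per line aligned under the opening parenthesis, constructor initializer lists are rewrapped, and assignment spacing is normalized. This is characteristic of clang-format at a comparable column limit; a hedged sketch of applying it, assuming the repository ships a .clang-format file matching the new style (neither that file nor the exact source layout appears in this diff, so the paths below are illustrative):

# reformat C++ sources and headers in place with the project's .clang-format
find src \( -name '*.cpp' -o -name '*.h' \) -print0 | xargs -0 clang-format -i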
This diff is collapsed.
......@@ -37,6 +37,7 @@ public:
float lora_scale;
const Device device;
public:
Tensor qweight;
Tensor wscales;
......@@ -69,12 +70,18 @@ public:
Tensor forward(Tensor x);
Tensor forward_silu(Tensor x);
std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
void forward(
Tensor x, Tensor out,
Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {},
Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {}, int numTokens = 0
);
std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
void forward(Tensor x,
Tensor out,
Tensor pool = {},
Tensor norm_q = {},
Tensor norm_k = {},
Tensor rotary_emb = {},
Tensor out_q = {},
Tensor out_k = {},
Tensor out_v = {},
int numTokens = 0);
std::variant<Tensor, QuantizedActivation>
forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
Tensor forward_quant(QuantizedActivation qact);
public:
......@@ -86,7 +93,7 @@ public:
const int in_features_pad;
const int out_features_pad;
const bool use_fp4;
int lora_rank;
std::vector<float> lora_scales; // every 16 ranks share a scale
......@@ -118,13 +125,16 @@ public:
Tensor act;
Tensor ascales;
};
public:
GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
public:
QuantizedActivation quantize(Tensor x, bool fuse_glu);
QuantizedActivation quantize(Tensor x, bool fuse_glu);
Tensor forward_quant(QuantizedActivation qact);
Tensor forward(Tensor x) { return forward_quant(quantize(x, false)); }
Tensor forward(Tensor x) {
return forward_quant(quantize(x, false));
}
public:
const int in_features;
......@@ -149,4 +159,4 @@ public:
public:
Tensor weight;
Tensor bias;
};
\ No newline at end of file
};
......@@ -10,8 +10,8 @@ void Module::copyWithCast(Tensor dst, Tensor src) {
nunchaku::kernels::cast(src, dst);
} else {
Tensor tmp;
tmp.buffer = dst.buffer;
tmp.shape = dst.shape;
tmp.buffer = dst.buffer;
tmp.shape = dst.shape;
tmp.scalarType = src.scalarType;
tmp.copy_(src);
nunchaku::kernels::cast(tmp, dst);
......
......@@ -7,7 +7,7 @@
class Module {
protected:
enum class ParamFlags : int {
None = 0,
None = 0,
Optional = 1,
LazyLoad = 2,
};
......@@ -19,7 +19,7 @@ protected:
Tensor src;
};
struct Param {
Tensor *tensor = nullptr;
Tensor *tensor = nullptr;
ParamFlags flags = ParamFlags::None;
TensorLazyLoadInfo lazyInfo;
......@@ -50,7 +50,7 @@ public:
std::string getPrefix() const {
std::string fullName = getFullName();
std::string prefix = fullName.empty() ? "" : fullName + ".";
std::string prefix = fullName.empty() ? "" : fullName + ".";
return prefix;
}
......@@ -80,7 +80,7 @@ public:
continue;
}
// keep loading params if param is not released
}
}
this->loadParam(key, *param.tensor, src);
// tensor->copy_(src);
}
......@@ -99,8 +99,8 @@ public:
}
TensorLazyLoadInfo &lazy = param.lazyInfo;
Tensor &dst = *param.tensor;
Tensor src = lazy.src;
Tensor &dst = *param.tensor;
Tensor src = lazy.src;
if (dst.valid()) {
continue;
......@@ -108,7 +108,8 @@ public:
dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device);
if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) {
throw std::runtime_error(spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
throw std::runtime_error(
spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
}
m->loadParam(key, dst, src);
}
......@@ -127,14 +128,10 @@ public:
});
}
void setLazyLoad(bool val) {
traverse([val](Module *m) {
m->enabledLazyLoad = val;
});
traverse([val](Module *m) { m->enabledLazyLoad = val; });
}
void setAutoCastFP16(bool val) {
traverse([val](Module *m) {
m->enabledAutoCastFP16 = val;
});
traverse([val](Module *m) { m->enabledAutoCastFP16 = val; });
}
protected:
......@@ -143,7 +140,8 @@ protected:
Tensor::FP16,
Tensor::BF16,
};
if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) && whitelist.contains(src.scalar_type())) {
if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) &&
whitelist.contains(src.scalar_type())) {
copyWithCast(dst, src);
} else {
dst.copy_(src);
......@@ -159,7 +157,7 @@ protected:
};
ChildrenRegisterHelper registerChildren(Module &module, std::string name) {
module.parent = this;
module.name = name;
module.name = name;
children.push_back(&module);
return ChildrenRegisterHelper(*this);
}
......@@ -174,13 +172,13 @@ protected:
ParamsRegisterHelper registerParams(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
if (param.valid()) {
params[name].tensor = &param;
params[name].flags = flags;
params[name].flags = flags;
if (checkFlag(flags, ParamFlags::LazyLoad) && param.valid()) {
TensorLazyLoadInfo &lazy = params[name].lazyInfo;
lazy.shape = param.shape;
lazy.type = param.dtype();
lazy.device = param.device();
lazy.shape = param.shape;
lazy.type = param.dtype();
lazy.device = param.device();
}
}
return ParamsRegisterHelper(*this);
......@@ -204,12 +202,12 @@ private:
void copyWithCast(Tensor dst, Tensor src);
public:
Module *parent = nullptr;
Module *parent = nullptr;
std::string name = "";
std::vector<Module *> children;
std::map<std::string, Param> params;
bool enabledLazyLoad = false;
bool enabledLazyLoad = false;
bool enabledAutoCastFP16 = true;
};
......@@ -226,12 +224,11 @@ struct LayerOffloadHelper {
std::unique_ptr<CUDAEventWrapper> eventComputeDone;
std::unique_ptr<CUDAEventWrapper> eventLoadDone;
LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
: offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload)
{
LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
: offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload) {
if (offload) {
streamCompute = std::make_unique<CUDAStreamWrapper>();
streamLoad = std::make_unique<CUDAStreamWrapper>();
streamLoad = std::make_unique<CUDAStreamWrapper>();
needWorkaround = checkWorkaround();
if (needWorkaround) {
......@@ -280,7 +277,7 @@ private:
}
eventComputeDone = std::move(nextComputeDone);
eventLoadDone = std::move(nextLoadDone);
eventLoadDone = std::move(nextLoadDone);
workaroundSynchronize();
}
......@@ -304,12 +301,12 @@ private:
return false;
}
}
#ifdef _WIN32
#ifdef _WIN32
return true;
#else
#else
return false;
#endif
#endif
}
void workaroundFlush() {
if (!needWorkaround) {
......@@ -323,4 +320,4 @@ private:
}
checkCUDA(cudaEventSynchronize(eventComputeDone->event));
}
};
\ No newline at end of file
};
......@@ -10,18 +10,11 @@
using spdlog::fmt_lib::format;
using namespace nunchaku;
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_pad(ceilDiv(dim, 128) * 128),
qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
pag_to_v(std::nullopt)
{
registerChildren
(qkv_proj, "qkv_proj")
(out_proj, "out_proj")
;
SanaLinearAttention::SanaLinearAttention(
int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device)
: dim(dim), dim_pad(ceilDiv(dim, 128) * 128), qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
out_proj(dim_pad, dim, bias, use_fp4, dtype, device), pag_to_v(std::nullopt) {
registerChildren(qkv_proj, "qkv_proj")(out_proj, "out_proj");
if (pag) {
pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
......@@ -33,8 +26,8 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
constexpr int HEAD_DIM = 32;
assert(x.ndims() == 3);
const int batch_size = x.shape[0];
const int num_tokens = x.shape[1];
const int batch_size = x.shape[0];
const int num_tokens = x.shape[1];
const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
assert(x.shape[2] == dim);
......@@ -54,24 +47,38 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
auto qact = qkv_proj.quantize(x, false);
Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());
kernels::gemm_w4a4(
qact.act,
qkv_proj.qweight,
{},
{},
qact.ascales,
qkv_proj.wscales,
{}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
vk, q,
qact.is_unsigned, qkv_proj.lora_scales, false,
qkv_proj.use_fp4,
*qkv_proj.wtscale.data_ptr<float>(),
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
{}, {}, {}, 0
);
kernels::gemm_w4a4(qact.act,
qkv_proj.qweight,
{},
{},
qact.ascales,
qkv_proj.wscales,
{},
{},
qact.lora_act,
qkv_proj.lora_up,
{},
{},
{},
{},
{},
qkv_proj.bias,
{},
vk,
q,
qact.is_unsigned,
qkv_proj.lora_scales,
false,
qkv_proj.use_fp4,
*qkv_proj.wtscale.data_ptr<float>(),
qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
{},
{},
{},
0);
debug("vk", vk);
debug("q", q);
......@@ -88,7 +95,6 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
q = q_unpad;
}
// kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
// return out_proj.forward(q);
......@@ -109,14 +115,14 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
if (cfg) {
assert(batch_size % 3 == 0);
x_org = x.slice(0, 0, batch_size * 2 / 3);
x_ptb = x.slice(0, batch_size * 2 / 3, batch_size);
x_org = x.slice(0, 0, batch_size * 2 / 3);
x_ptb = x.slice(0, batch_size * 2 / 3, batch_size);
out_org = out.slice(0, 0, batch_size * 2 / 3);
out_ptb = out.slice(0, batch_size * 2 / 3, batch_size);
} else {
assert(batch_size % 2 == 0);
x_org = x.slice(0, 0, batch_size / 2);
x_ptb = x.slice(0, batch_size / 2, batch_size);
x_org = x.slice(0, 0, batch_size / 2);
x_ptb = x.slice(0, batch_size / 2, batch_size);
out_org = out.slice(0, 0, batch_size / 2);
out_ptb = out.slice(0, batch_size / 2, batch_size);
}
......@@ -129,17 +135,13 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
return out;
}
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
num_heads(num_heads), head_dim(head_dim),
q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device)
{
registerChildren
(q_linear, "q_linear")
(kv_linear, "kv_linear")
(out_proj, "out_proj")
;
MultiHeadCrossAttention::MultiHeadCrossAttention(
int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device)
: num_heads(num_heads), head_dim(head_dim),
q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device) {
registerChildren(q_linear, "q_linear")(kv_linear, "kv_linear")(out_proj, "out_proj");
}
Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) {
......@@ -155,22 +157,28 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
assert(cu_seqlens_img.shape[0] == batch_size + 1);
assert(cu_seqlens_txt.shape[0] == batch_size + 1);
Tensor q = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim});
Tensor q = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim});
Tensor kv = kv_linear.forward(cond).view({num_tokens_txt, num_heads * 2, head_dim});
Tensor k = kv.slice(1, 0, num_heads);
Tensor v = kv.slice(1, num_heads, num_heads * 2);
Tensor attn_output = mha_varlen_fwd(
q, k, v,
cu_seqlens_img, cu_seqlens_txt,
num_tokens_img, num_tokens_txt,
0.0f,
pow(q.shape[-1], (-0.5)),
false, false,
-1, -1,
false
).front().view({batch_size, num_tokens_img, num_heads * head_dim});
Tensor attn_output = mha_varlen_fwd(q,
k,
v,
cu_seqlens_img,
cu_seqlens_txt,
num_tokens_img,
num_tokens_txt,
0.0f,
pow(q.shape[-1], (-0.5)),
false,
false,
-1,
-1,
false)
.front()
.view({batch_size, num_tokens_img, num_heads * head_dim});
// Tensor attn_output = mha_fwd(q, k, v,
// 0.0f,
......@@ -181,17 +189,13 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
return out_proj.forward(attn_output);
}
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
in_features(in_features), hidden_features(hidden_features),
inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
depth_conv(hidden_features * 2, true, dtype, device),
point_conv(hidden_features, in_features, false, use_fp4, dtype, device)
{
registerChildren
(inverted_conv, "inverted_conv")
(depth_conv, "depth_conv")
(point_conv, "point_conv")
;
SanaGLUMBConv::SanaGLUMBConv(
int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device)
: in_features(in_features), hidden_features(hidden_features),
inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
depth_conv(hidden_features * 2, true, dtype, device),
point_conv(hidden_features, in_features, false, use_fp4, dtype, device) {
registerChildren(inverted_conv, "inverted_conv")(depth_conv, "depth_conv")(point_conv, "point_conv");
}
Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
......@@ -203,33 +207,39 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
debug("inverted_conv_output", x);
x = depth_conv.forward(x);
debug("depth_conv_output", x);
x = x.view({x.shape[0], H * W, x.shape[-1]});
x = x.view({x.shape[0], H * W, x.shape[-1]});
auto qact = point_conv.quantize(x, true);
return point_conv.forward_quant(qact);
}
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
attn(hidden_size, false, pag, use_fp4, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
ff(hidden_size, intermediate_size, use_fp4, dtype, device),
norm1(hidden_size, 1e-6, false, dtype, device),
norm2(hidden_size, 1e-6, false, dtype, device)
{
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size,
int intermediate_size,
int num_cross_attention_heads,
bool pag,
bool use_fp4,
Tensor::ScalarType dtype,
Device device)
: hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
attn(hidden_size, false, pag, use_fp4, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
ff(hidden_size, intermediate_size, use_fp4, dtype, device), norm1(hidden_size, 1e-6, false, dtype, device),
norm2(hidden_size, 1e-6, false, dtype, device) {
this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);
registerChildren
(attn, "attn")
(cross_attn, "cross_attn")
(ff, "ff")
;
registerChildren(attn, "attn")(cross_attn, "cross_attn")(ff, "ff");
registerParams
(this->scale_shift_table, "scale_shift_table")
;
registerParams(this->scale_shift_table, "scale_shift_table");
}
Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor timestep,
Tensor cu_seqlens_img,
Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg) {
nvtxRangePushA("SanaLinearTransformerBlock");
......@@ -257,7 +267,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
{
nvtxRangePushA("LinearAttention");
Tensor residual = hidden_states;
Tensor residual = hidden_states;
Tensor norm_hidden_states = norm1.forward(hidden_states);
kernels::mul_add_batch(norm_hidden_states, scale_msa, true, 1, shift_msa, true);
debug("norm_hidden_states_la", norm_hidden_states);
......@@ -311,9 +321,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
return hidden_states;
}
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) :
config(config)
{
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : config(config) {
const int inner_dim = config.num_attention_heads * config.attention_head_dim;
for (int i = 0; i < config.num_layers; i++) {
transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>(
......@@ -322,20 +330,34 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
config.num_cross_attention_heads,
std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
config.use_fp4,
dtype, device
));
dtype,
device));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
}
}
Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
Tensor SanaModel::forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor timestep,
Tensor cu_seqlens_img,
Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg,
bool skip_first_layer) {
for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
auto &&block = transformer_blocks[i];
hidden_states = block->forward(
hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
cfg
);
auto &&block = transformer_blocks[i];
hidden_states = block->forward(hidden_states,
encoder_hidden_states,
timestep,
cu_seqlens_img,
cu_seqlens_txt,
H,
W,
pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) !=
config.pag_layers.end(),
cfg);
}
return hidden_states;
}
......@@ -35,7 +35,7 @@ public:
private:
GEMM_W4A4 q_linear;
GEMM_F16 kv_linear;
GEMM_F16 kv_linear;
GEMM_W4A4 out_proj;
};
......@@ -57,9 +57,23 @@ private:
class SanaLinearTransformerBlock : public Module {
public:
SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
SanaLinearTransformerBlock(int hidden_size,
int intermediate_size,
int num_cross_attention_heads,
bool pag,
bool use_fp4,
Tensor::ScalarType dtype,
Device device);
Tensor forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor timestep,
Tensor cu_seqlens_img,
Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg);
public:
const int hidden_size;
......@@ -89,11 +103,20 @@ struct SanaConfig {
class SanaModel : public Module {
public:
SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);
Tensor forward(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor timestep,
Tensor cu_seqlens_img,
Tensor cu_seqlens_txt,
int H,
int W,
bool pag,
bool cfg,
bool skip_first_layer);
public:
const SanaConfig config;
public:
std::vector<std::unique_ptr<SanaLinearTransformerBlock>> transformer_blocks;
};
\ No newline at end of file
};