"vscode:/vscode.git/clone" did not exist on "1ddb70f9e5faaa70bfa333a053ec6b1ccc83c311"
Commit bf0813a6 authored by Hyunsung Lee, committed by Zhekai Zhang

Add SanaModel caching

parent 65d7e47a
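The commit includes an end-to-end usage snippet (below): it loads the SVDQuant INT4 Sana 1.6B transformer, plugs it into diffusers' SanaPipeline, enables caching with apply_cache_on_pipe, and renders a warm-up image.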
import torch
from diffusers import SanaPipeline

from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    variant="bf16",
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)

apply_cache_on_pipe(pipe, residual_diff_threshold=0.25)

# WarmUp
prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.5,
    num_inference_steps=20,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m.png")
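Since the call above is labelled a warm-up, a natural follow-up is to time a second call to see the steady-state speed with caching enabled. A minimal sketch that reuses the pipe and prompt defined above (timing code only, not part of the commit):

import time

start = time.time()
image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.5,
    num_inference_steps=20,
    generator=torch.Generator().manual_seed(42),
).images[0]
# Compare against a pipeline patched with residual_diff_threshold=0,
# which falls back to running every transformer block at every step.
print(f"cached run: {time.time() - start:.2f} s")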
@@ -7,6 +7,8 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
     pipe_cls_name = pipe.__class__.__name__
     if pipe_cls_name.startswith("Flux"):
         from .flux import apply_cache_on_pipe as apply_cache_on_pipe_fn
+    elif pipe_cls_name.startswith("Sana"):
+        from .sana import apply_cache_on_pipe as apply_cache_on_pipe_fn
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
     return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
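The dispatcher keys only on the pipeline class-name prefix, so any diffusers pipeline whose class name starts with "Sana" is routed to the new .sana adapter, and anything else still raises. A standalone sketch of that lookup (pick_backend is illustrative, not part of the commit):

def pick_backend(pipe_cls_name: str) -> str:
    # Mirrors the prefix dispatch in apply_cache_on_pipe above.
    if pipe_cls_name.startswith("Flux"):
        return "flux"
    elif pipe_cls_name.startswith("Sana"):
        return "sana"
    raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

print(pick_backend("SanaPipeline"))        # sana
print(pick_backend("FluxPipeline"))        # flux
# pick_backend("StableDiffusion3Pipeline") # would raise ValueError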
@@ -13,7 +13,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-            utils.CachedTransformerBlocks(
+            utils.FluxCachedTransformerBlocks(
                 transformer=transformer,
                 residual_diff_threshold=residual_diff_threshold,
                 return_hidden_states_first=False,
...
import functools
import unittest

import torch
from diffusers import DiffusionPipeline, SanaTransformer2DModel

from ...caching import utils


def apply_cache_on_transformer(transformer: SanaTransformer2DModel, *, residual_diff_threshold=0.12):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.SanaCachedTransformerBlocks(
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with unittest.mock.patch.object(self, "transformer_blocks", cached_transformer_blocks):
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
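A few usage notes on the adapter above, assuming the module lands at nunchaku.caching.diffusers_adapters.sana as its relative imports suggest: apply_cache_on_pipe patches the pipeline class's __call__ to open a fresh cache context per generation and, unless shallow_patch=True, also wraps the transformer's blocks; both patches are guarded by _is_cached, so re-applying is a no-op. A sketch with a freshly loaded, not-yet-patched pipe and illustrative threshold values:

from nunchaku.caching.diffusers_adapters.sana import (
    apply_cache_on_pipe,
    apply_cache_on_transformer,
)

# Patch only the pipeline __call__ (per-call cache context), then patch the
# transformer separately with a custom threshold.
apply_cache_on_pipe(pipe, shallow_patch=True)
apply_cache_on_transformer(pipe.transformer, residual_diff_threshold=0.2)

# Calling it again does nothing: the _is_cached guard returns early,
# so the first threshold stays in effect.
apply_cache_on_transformer(pipe.transformer, residual_diff_threshold=0.5)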
@@ -3,7 +3,7 @@
 import contextlib
 import dataclasses
 from collections import defaultdict
-from typing import DefaultDict, Dict
+from typing import DefaultDict, Dict, Optional

 import torch
 from torch import nn
@@ -81,18 +81,19 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
 @torch.compiler.disable
 def apply_prev_hidden_states_residual(
-    hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+    hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
     hidden_states_residual = get_buffer("hidden_states_residual")
     assert hidden_states_residual is not None, "hidden_states_residual must be set before"
     hidden_states = hidden_states_residual + hidden_states
-
-    encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
-    assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
-    encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
     hidden_states = hidden_states.contiguous()
-    encoder_hidden_states = encoder_hidden_states.contiguous()
+
+    if encoder_hidden_states is not None:
+        encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
+        assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
+        encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
+        encoder_hidden_states = encoder_hidden_states.contiguous()

     return hidden_states, encoder_hidden_states
@@ -108,8 +109,124 @@ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
     )
     return can_use_cache


+class SanaCachedTransformerBlocks(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer=None,
+        residual_diff_threshold,
+        verbose: bool = False,
+    ):
+        super().__init__()
+        self.transformer = transformer
+        self.transformer_blocks = transformer.transformer_blocks
+        self.residual_diff_threshold = residual_diff_threshold
+        self.verbose = verbose
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask=None,
+        timestep=None,
+        post_patch_height=None,
+        post_patch_width=None,
+    ):
+        batch_size = hidden_states.shape[0]
+        if self.residual_diff_threshold <= 0.0 or batch_size > 2:
+            if batch_size > 2:
+                print("Batch size > 2 (for SANA CFG) currently not supported")
+            first_transformer_block = self.transformer_blocks[0]
+            hidden_states = first_transformer_block(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                timestep=timestep,
+                height=post_patch_height,
+                width=post_patch_width,
+                skip_first_layer=False,
+            )
+            return hidden_states
+
+        original_hidden_states = hidden_states
+        first_transformer_block = self.transformer_blocks[0]
+        hidden_states = first_transformer_block.forward_layer_at(
+            0,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            timestep=timestep,
+            height=post_patch_height,
+            width=post_patch_width,
+        )
+        first_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        can_use_cache = get_can_use_cache(
+            first_hidden_states_residual,
+            threshold=self.residual_diff_threshold,
+            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            del first_hidden_states_residual
+            if self.verbose:
+                print("Cache hit!!!")
+            hidden_states, _ = apply_prev_hidden_states_residual(hidden_states, None)
+        else:
+            if self.verbose:
+                print("Cache miss!!!")
+            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
+            del first_hidden_states_residual
+            hidden_states, hidden_states_residual = self.call_remaining_transformer_blocks(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                timestep=timestep,
+                post_patch_height=post_patch_height,
+                post_patch_width=post_patch_width,
+            )
+            set_buffer("hidden_states_residual", hidden_states_residual)
+        torch._dynamo.graph_break()
+
+        return hidden_states
+
+    def call_remaining_transformer_blocks(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask=None,
+        timestep=None,
+        post_patch_height=None,
+        post_patch_width=None,
+    ):
+        first_transformer_block = self.transformer_blocks[0]
+        original_hidden_states = hidden_states
+        hidden_states = first_transformer_block(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            timestep=timestep,
+            height=post_patch_height,
+            width=post_patch_width,
+            skip_first_layer=True,
+        )
+        hidden_states_residual = hidden_states - original_hidden_states
+        return hidden_states, hidden_states_residual
+
+
-class CachedTransformerBlocks(nn.Module):
+class FluxCachedTransformerBlocks(nn.Module):
     def __init__(
         self,
         *,
...
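For orientation, SanaCachedTransformerBlocks implements a first-block cache: each denoising step runs only block 0, compares its output residual against the residual stored at the previous step (get_can_use_cache / are_two_tensors_similar, the latter not shown in this hunk), and on a hit re-applies the cached residual of the remaining blocks instead of running them. The sketch below assumes the similarity test is a mean relative L1 difference, the usual choice for this kind of cache; shapes and names are illustrative only:

import torch

def relative_l1(prev: torch.Tensor, cur: torch.Tensor) -> float:
    # Assumed metric: mean |change| normalised by the previous residual's magnitude.
    return ((cur - prev).abs().mean() / prev.abs().mean()).item()

# Illustrative shapes: 1 sample, 1024 image tokens, hidden size 2240.
prev_residual = torch.randn(1, 1024, 2240)
cur_residual = prev_residual + 0.01 * torch.randn_like(prev_residual)

threshold = 0.25
if relative_l1(prev_residual, cur_residual) < threshold:
    # Cache hit: add the stored "hidden_states_residual" buffer to block 0's
    # output (apply_prev_hidden_states_residual) and skip blocks 1..N-1.
    print("cache hit")
else:
    # Cache miss: run the remaining blocks (skip_first_layer=True) and
    # refresh both stored residual buffers.
    print("cache miss")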
@@ -26,15 +26,16 @@ public:
     }

     torch::Tensor forward(
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor timestep,
         torch::Tensor cu_seqlens_img,
         torch::Tensor cu_seqlens_txt,
         int H,
         int W,
         bool pag,
-        bool cfg)
+        bool cfg,
+        bool skip_first_layer = false)
     {
         checkModel();
         CUDADeviceContext ctx(deviceId);
@@ -54,7 +55,8 @@ public:
             from_torch(cu_seqlens_img),
             from_torch(cu_seqlens_txt),
             H, W,
-            pag, cfg
+            pag, cfg,
+            skip_first_layer
         );
         torch::Tensor output = to_torch(result);
@@ -65,15 +67,15 @@ public:
     torch::Tensor forward_layer(
         int64_t idx,
         torch::Tensor hidden_states,
         torch::Tensor encoder_hidden_states,
         torch::Tensor timestep,
         torch::Tensor cu_seqlens_img,
         torch::Tensor cu_seqlens_txt,
         int H,
         int W,
         bool pag,
         bool cfg)
     {
         checkModel();
         CUDADeviceContext ctx(deviceId);
...
@@ -30,6 +30,7 @@ class NunchakuSanaTransformerBlocks(nn.Module):
         timestep: Optional[torch.LongTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        skip_first_layer: Optional[bool] = False,
     ):
         batch_size = hidden_states.shape[0]
@@ -69,6 +70,61 @@ class NunchakuSanaTransformerBlocks(nn.Module):
                 width,
                 batch_size % 3 == 0,  # pag is set when loading the model, FIXME: pag_scale == 0
                 True,  # TODO: find a way to detect if we are doing CFG
+                skip_first_layer,
             )
             .to(original_dtype)
             .to(original_device)
         )
+
+    def forward_layer_at(
+        self,
+        idx: int,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ):
+        batch_size = hidden_states.shape[0]
+        img_tokens = hidden_states.shape[1]
+        txt_tokens = encoder_hidden_states.shape[1]
+        original_dtype = hidden_states.dtype
+        original_device = hidden_states.device
+
+        assert encoder_attention_mask is not None
+        assert encoder_attention_mask.shape == (batch_size, 1, txt_tokens)
+        mask = encoder_attention_mask.reshape(batch_size, txt_tokens)
+        nunchaku_encoder_hidden_states = encoder_hidden_states[mask > -9000]
+        cu_seqlens_txt = F.pad((mask > -9000).sum(dim=1).cumsum(dim=0), pad=(1, 0), value=0).to(torch.int32)
+        cu_seqlens_img = torch.arange(
+            0, (batch_size + 1) * img_tokens, img_tokens, dtype=torch.int32, device=self.device
+        )
+
+        if height is None and width is None:
+            height = width = int(img_tokens**0.5)
+        elif height is None:
+            height = img_tokens // width
+        elif width is None:
+            width = img_tokens // height
+        assert height * width == img_tokens
+
+        return (
+            self.m.forward_layer(
+                idx,
+                hidden_states.to(self.dtype).to(self.device),
+                nunchaku_encoder_hidden_states.to(self.dtype).to(self.device),
+                timestep.to(self.dtype).to(self.device),
+                cu_seqlens_img.to(self.device),
+                cu_seqlens_txt.to(self.device),
+                height,
+                width,
+                batch_size % 3 == 0,  # pag is set when loading the model, FIXME: pag_scale == 0
+                True,  # TODO: find a way to detect if we are doing CFG
+            )
+            .to(original_dtype)
+            .to(original_device)
+        )
...
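One detail worth noting in forward_layer_at above: when height and width are not supplied, they are inferred from the image-token count. A standalone sketch of that fallback with an illustrative token count:

img_tokens = 1024  # e.g. a 1024x1024 Sana generation patchified to a 32x32 token grid
height = width = None

if height is None and width is None:
    height = width = int(img_tokens**0.5)
elif height is None:
    height = img_tokens // width
elif width is None:
    width = img_tokens // height
assert height * width == img_tokens
print(height, width)  # 32 32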
+#include <iostream>
 #include "SanaModel.h"
 #include "kernels/zgemm/zgemm.h"
 #include "flash_api.h"
@@ -8,6 +10,7 @@
 using spdlog::fmt_lib::format;
 using namespace nunchaku;

 SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_pad(ceilDiv(dim, 128) * 128),
@@ -28,7 +31,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_
 Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
     constexpr int HEAD_DIM = 32;

     assert(x.ndims() == 3);
     const int batch_size = x.shape[0];
     const int num_tokens = x.shape[1];
@@ -45,7 +48,7 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
         for (int i = 0; i < batch_size; i++) {
             x_pad.slice(0, i, i + 1).slice(1, 0, num_tokens).copy_(x.slice(0, i, i + 1));
         }
         x = x_pad;
     }
@@ -55,14 +58,14 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
     Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());

     kernels::gemm_w4a4(
         qact.act,
         qkv_proj.qweight,
         {},
         {},
         qact.ascales,
         qkv_proj.wscales,
         {}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
         vk, q,
         qact.is_unsigned, qkv_proj.lora_scales, false,
         qkv_proj.use_fp4,
         *qkv_proj.wtscale.data_ptr<float>(),
@@ -118,12 +121,12 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
     }

     this->forward(x_org, out_org);

     Tensor v_ptb = this->pag_to_v.value().forward(x_ptb);
     this->out_proj.forward(v_ptb, out_ptb);

     return out;
 }

 MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     num_heads(num_heads), head_dim(head_dim),
@@ -143,7 +146,7 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
     assert(cond.ndims() == 2);
     assert(cu_seqlens_img.ndims() == 1);
     assert(cu_seqlens_txt.ndims() == 1);

     const int batch_size = x.shape[0];
     const int num_tokens_img = x.shape[1];
     const int num_tokens_txt = cond.shape[0];
@@ -163,21 +166,21 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
         num_tokens_img, num_tokens_txt,
         0.0f,
         pow(q.shape[-1], (-0.5)),
         false, false,
         -1, -1,
         false
     ).front().view({batch_size, num_tokens_img, num_heads * head_dim});

     // Tensor attn_output = mha_fwd(q, k, v,
     //     0.0f,
     //     pow(q.shape[-1], (-0.5)),
     //     false, -1, -1, false
     // ).front().view({B, N, num_heads * head_dim});

     return out_proj.forward(attn_output);
 }

 SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     in_features(in_features), hidden_features(hidden_features),
     inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
     depth_conv(hidden_features * 2, true, dtype, device),
@@ -204,7 +207,7 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
     return point_conv.forward_quant(qact);
 }

 SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
     attn(hidden_size, false, pag, use_fp4, dtype, device),
     cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
@@ -240,7 +243,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
     kernels::mul_add_batch(timestep, {}, false, 0, this->scale_shift_table, false);
     debug("shifted_timestep", timestep);

     std::array<Tensor, 6> chunked;
     for (int i = 0; i < 6; i++) {
         chunked[i] = timestep.slice(1, i, i + 1);
@@ -299,7 +302,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
             nvtxRangePop();
         }
         nvtxRangePop();

     debug("hidden_states_out", hidden_states);
@@ -307,7 +310,7 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_
     return hidden_states;
 }

 SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) :
     config(config)
 {
     const int inner_dim = config.num_attention_heads * config.attention_head_dim;
@@ -324,8 +327,8 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
     }
 }

-Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
-    for (int i = 0; i < config.num_layers; i++) {
+Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer) {
+    for (int i = (skip_first_layer ? 1 : 0); i < config.num_layers; i++) {
         auto &&block = transformer_blocks[i];
         hidden_states = block->forward(
             hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
...
@@ -89,7 +89,7 @@ struct SanaConfig {
 class SanaModel : public Module {
 public:
     SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
-    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
+    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg, bool skip_first_layer);

 public:
     const SanaConfig config;
...