"vscode:/vscode.git/clone" did not exist on "c1bbf5ddee41d8e96cda6aa2909f902374095498"
Commit 463544a1 authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #2694 failed with stages
in 0 seconds
{
"0000.jpg": "给这个女生的脖子上戴一个带有红宝石的吊坠。",
"0001.png": "让她哭。",
"0002.jpg": "外套改用头层小牛皮制作。",
"0003.png": "将图像转换为漫画风格。",
"0004.jpg": "将文本 'TRAIN' 替换为 'PLANE'"
}
{
"0000.jpg": "Add pendant with a ruby around this girl's neck.",
"0001.png": "Let her cry.",
"0002.jpg": "Change the outerwear to be made of top-grain calfskin.",
"0003.png": "Change image to anime style.",
"0004.jpg": "Replace 'TRAIN' with 'PLANE'"
}
import argparse
import datetime
import json
import itertools
import math
import os
import time
from pathlib import Path
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image, ImageOps
from safetensors.torch import load_file
from torchvision.transforms import functional as F
from tqdm import tqdm
import sampling
from modules.autoencoder import AutoEncoder
from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder
from modules.model_edit import Step1XParams, Step1XEdit
def cudagc():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True):
if Path(ckpt_path).suffix == ".safetensors":
state_dict = load_file(ckpt_path, device)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")
missing, unexpected = model.load_state_dict(
state_dict, strict=strict, assign=assign
)
if len(missing) > 0 and len(unexpected) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
print("\n" + "-" * 79 + "\n")
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
elif len(missing) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
elif len(unexpected) > 0:
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
return model
def load_models(
dit_path=None,
ae_path=None,
qwen2vl_model_path=None,
device="cuda",
max_length=256,
dtype=torch.bfloat16,
):
qwen2vl_encoder = Qwen2VLEmbedder(
qwen2vl_model_path,
device=device,
max_length=max_length,
dtype=dtype,
)
with torch.device("meta"):
ae = AutoEncoder(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
step1x_params = Step1XParams(
in_channels=64,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
)
dit = Step1XEdit(step1x_params)
ae = load_state_dict(ae, ae_path, 'cpu')
dit = load_state_dict(
dit, dit_path, 'cpu'
)
ae = ae.to(dtype=torch.float32)
return ae, dit, qwen2vl_encoder
class ImageGenerator:
def __init__(
self,
dit_path=None,
ae_path=None,
qwen2vl_model_path=None,
device="cuda",
max_length=640,
dtype=torch.bfloat16,
quantized=False,
offload=False,
) -> None:
self.device = torch.device(device)
self.ae, self.dit, self.llm_encoder = load_models(
dit_path=dit_path,
ae_path=ae_path,
qwen2vl_model_path=qwen2vl_model_path,
max_length=max_length,
dtype=dtype,
)
if not quantized:
self.dit = self.dit.to(dtype=torch.bfloat16)
if not offload:
self.dit = self.dit.to(device=self.device)
self.ae = self.ae.to(device=self.device)
self.quantized = quantized
self.offload = offload
def prepare(self, prompt, img, ref_image, ref_image_raw):
bs, _, h, w = img.shape
bs, _, ref_h, ref_w = ref_image.shape
assert h == ref_h and w == ref_w
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
elif bs >= 1 and isinstance(prompt, str):
prompt = [prompt] * bs
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None]
ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :]
ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
if self.offload:
self.llm_encoder = self.llm_encoder.to(self.device)
txt, mask = self.llm_encoder(prompt, ref_image_raw)
if self.offload:
self.llm_encoder = self.llm_encoder.cpu()
cudagc()
txt_ids = torch.zeros(bs, txt.shape[1], 3)
img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2)
img_ids = torch.cat([img_ids, ref_img_ids], dim=-2)
return {
"img": img,
"mask": mask,
"img_ids": img_ids.to(img.device),
"llm_embedding": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
}
@staticmethod
def process_diff_norm(diff_norm, k):
pow_result = torch.pow(diff_norm, k)
result = torch.where(
diff_norm > 1.0,
pow_result,
torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm),
)
return result
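# The denoise loop below runs classifier-free guidance with a norm-adjusted update:
# while t_curr > timesteps_truncate, the (cond - uncond) difference is divided by
# process_diff_norm(diff_norm, k), which dampens tokens whose per-token norm exceeds 1
# (raised to the power k) and leaves smaller differences untouched; otherwise the
# standard CFG update uncond + cfg_guidance * (cond - uncond) is used.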
def denoise(
self,
img: torch.Tensor,
img_ids: torch.Tensor,
llm_embedding: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: list[float],
cfg_guidance: float = 4.5,
mask=None,
show_progress=False,
timesteps_truncate=1.0,
):
if self.offload:
self.dit = self.dit.to(self.device)
if show_progress:
pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...')
else:
pbar = itertools.pairwise(timesteps)
for t_curr, t_prev in pbar:
if img.shape[0] == 1 and cfg_guidance != -1:
img = torch.cat([img, img], dim=0)
t_vec = torch.full(
(img.shape[0],), t_curr, dtype=img.dtype, device=img.device
)
txt, vec = self.dit.connector(llm_embedding, t_vec, mask)
pred = self.dit(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
)
if cfg_guidance != -1:
cond, uncond = (
pred[0 : pred.shape[0] // 2, :],
pred[pred.shape[0] // 2 :, :],
)
if t_curr > timesteps_truncate:
diff = cond - uncond
diff_norm = torch.norm(diff, dim=(2), keepdim=True)
pred = uncond + cfg_guidance * (
cond - uncond
) / self.process_diff_norm(diff_norm, k=0.4)
else:
pred = uncond + cfg_guidance * (cond - uncond)
tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred
img_input_length = img.shape[1] // 2
img = torch.cat(
[
tem_img[:, :img_input_length],
img[ : img.shape[0] // 2, img_input_length:],
], dim=1
)
if self.offload:
self.dit = self.dit.cpu()
cudagc()
return img[:, :img.shape[1] // 2]
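# unpack reverses the 2x2 patchification performed in prepare(): token sequences of
# shape (b, (h/16) * (w/16), 64) are rearranged back into latents of shape (b, 16, h/8, w/8).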
@staticmethod
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
@staticmethod
def load_image(image):
from PIL import Image
if isinstance(image, np.ndarray):
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
image = image.unsqueeze(0)
return image
elif isinstance(image, Image.Image):
image = F.to_tensor(image.convert("RGB"))
image = image.unsqueeze(0)
return image
elif isinstance(image, torch.Tensor):
return image
elif isinstance(image, str):
image = F.to_tensor(Image.open(image).convert("RGB"))
image = image.unsqueeze(0)
return image
else:
raise ValueError(f"Unsupported image type: {type(image)}")
def output_process_image(self, resize_img, image_size):
res_image = resize_img.resize(image_size)
return res_image
def input_process_image(self, img, img_size=512):
# Resize so the total pixel count is close to img_size**2 while preserving the aspect ratio
w, h = img.size
r = w / h
if w > h:
w_new = math.ceil(math.sqrt(img_size * img_size * r))
h_new = math.ceil(w_new / r)
else:
h_new = math.ceil(math.sqrt(img_size * img_size / r))
w_new = math.ceil(h_new * r)
h_new = math.ceil(h_new) // 16 * 16
w_new = math.ceil(w_new) // 16 * 16
img_resized = img.resize((w_new, h_new))
return img_resized, img.size
@torch.inference_mode()
def generate_image(
self,
prompt,
negative_prompt,
ref_images,
num_steps,
cfg_guidance,
seed,
num_samples=1,
init_image=None,
image2image_strength=0.0,
show_progress=False,
size_level=512,
):
assert num_samples == 1, "num_samples > 1 is not supported yet."
ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level)
width, height = ref_images_raw.width, ref_images_raw.height
ref_images_raw = self.load_image(ref_images_raw)
ref_images_raw = ref_images_raw.to(self.device)
if self.offload:
self.ae = self.ae.to(self.device)
ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
if self.offload:
self.ae = self.ae.cpu()
cudagc()
seed = int(seed)
seed = torch.Generator(device="cpu").seed() if seed < 0 else seed
t0 = time.perf_counter()
if init_image is not None:
init_image = self.load_image(init_image)
init_image = init_image.to(self.device)
init_image = torch.nn.functional.interpolate(init_image, (height, width))
if self.offload:
self.ae = self.ae.to(self.device)
init_image = self.ae.encode(init_image.to() * 2 - 1)
if self.offload:
self.ae = self.ae.cpu()
cudagc()
x = torch.randn(
num_samples,
16,
height // 8,
width // 8,
device=self.device,
dtype=torch.bfloat16,
generator=torch.Generator(device=self.device).manual_seed(seed),
)
timesteps = sampling.get_schedule(
num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True
)
if init_image is not None:
t_idx = int((1 - image2image_strength) * num_steps)
t = timesteps[t_idx]
timesteps = timesteps[t_idx:]
x = t * x + (1.0 - t) * init_image.to(x.dtype)
x = torch.cat([x, x], dim=0)
ref_images = torch.cat([ref_images, ref_images], dim=0)
ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)
inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw)
with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
x = self.denoise(
**inputs,
cfg_guidance=cfg_guidance,
timesteps=timesteps,
show_progress=show_progress,
timesteps_truncate=1.0,
)
x = self.unpack(x.float(), height, width)
if self.offload:
self.ae = self.ae.to(self.device)
x = self.ae.decode(x)
if self.offload:
self.ae = self.ae.cpu()
cudagc()
x = x.clamp(-1, 1)
x = x.mul(0.5).add(0.5)
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s.")
images_list = []
for img in x.float():
images_list.append(self.output_process_image(F.to_pil_image(img), img_info))
return images_list
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, required=True, help='Path to the model checkpoint')
parser.add_argument('--input_dir', type=str, required=True, help='Path to the input image directory')
parser.add_argument('--output_dir', type=str, required=True, help='Path to the output image directory')
parser.add_argument('--json_path', type=str, required=True, help='Path to the JSON file containing image names and prompts')
parser.add_argument('--seed', type=int, default=42, help='Random seed for generation')
parser.add_argument('--num_steps', type=int, default=28, help='Number of diffusion steps')
parser.add_argument('--cfg_guidance', type=float, default=6.0, help='CFG guidance strength')
parser.add_argument('--size_level', default=512, type=int)
parser.add_argument('--offload', action='store_true', help='Use offload for large models')
parser.add_argument('--quantized', action='store_true', help='Use fp8 model weights')
args = parser.parse_args()
assert os.path.exists(args.input_dir), f"Input directory {args.input_dir} does not exist."
assert os.path.exists(args.json_path), f"JSON file {args.json_path} does not exist."
args.output_dir = args.output_dir.rstrip('/') + ('-offload' if args.offload else "") + ('-quantized' if args.quantized else "") + f"-{args.size_level}"
os.makedirs(args.output_dir, exist_ok=True)
image_and_prompts = json.load(open(args.json_path, 'r'))
image_edit = ImageGenerator(
ae_path=os.path.join(args.model_path, 'vae.safetensors'),
dit_path=os.path.join(args.model_path, "step1x-edit-i1258-FP8.safetensors" if args.quantized else "step1x-edit-i1258.safetensors"),
qwen2vl_model_path=('Qwen/Qwen2.5-VL-7B-Instruct'),
max_length=640,
quantized=args.quantized,
offload=args.offload,
)
time_list = []
for image_name, prompt in image_and_prompts.items():
image_path = os.path.join(args.input_dir, image_name)
output_path = os.path.join(args.output_dir, image_name)
start_time = time.time()
image = image_edit.generate_image(
prompt,
negative_prompt="",
ref_images=Image.open(image_path).convert("RGB"),
num_samples=1,
num_steps=args.num_steps,
cfg_guidance=args.cfg_guidance,
seed=args.seed,
show_progress=True,
size_level=args.size_level,
)[0]
print(f"Time taken: {time.time() - start_time:.2f} seconds")
time_list.append(time.time() - start_time)
image.save(output_path, lossless=True)
print(f'average time for {args.output_dir}: ', sum(time_list[1:]) / len(time_list[1:]))
if __name__ == "__main__":
main()
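# Example invocation (illustrative only; the script name, checkpoint layout and paths
# below are placeholders, not part of the original repository):
#   python inference.py \
#       --model_path ./weights/Step1X-Edit \
#       --input_dir ./examples \
#       --json_path ./examples/prompt_en.json \
#       --output_dir ./results \
#       --num_steps 28 --cfg_guidance 6.0 --seed 42 --size_level 512 --offload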
# Unique model identifier
modelCode=1532
# Model name
modelName=Step1X-Edit_pytorch
# Model description
modelDescription=Step1X-Edit is a new general-purpose image editing algorithm that significantly improves image-editing performance by combining a powerful multimodal large language model with a diffusion-based image decoder.
# Application scenarios
appScenario=inference,multimodal,painting,anime,media,manufacturing,broadcast media,home,education
# Framework type
frameType=Pytorch
import math
import torch
import torch.nn.functional as F
try:
import flash_attn
from flash_attn.flash_attn_interface import (
_flash_attn_forward,
flash_attn_func,
flash_attn_varlen_func,
)
except ImportError:
flash_attn = None
flash_attn_varlen_func = None
_flash_attn_forward = None
flash_attn_func = None
MEMORY_LAYOUT = {
# "flash" mode:
# pre-process: input is [batch_size, seq_len, num_heads, head_dim]
# post-process: shape is left unchanged
"flash": (
lambda x: x, # keep shape
lambda x: x, # keep shape
),
# "torch"/"vanilla" modes:
# pre-process: swap the sequence and head dimensions [B,S,A,D] -> [B,A,S,D]
# post-process: swap back to the original layout [B,A,S,D] -> [B,S,A,D]
"torch": (
lambda x: x.transpose(1, 2), # (B,S,A,D) -> (B,A,S,D)
lambda x: x.transpose(1, 2), # (B,A,S,D) -> (B,S,A,D)
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def attention(
q,
k,
v,
mode="flash",
drop_rate=0,
attn_mask=None,
causal=False,
):
"""
执行QKV自注意力计算
Args:
q (torch.Tensor): 查询张量,形状 [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): 键张量,形状 [batch_size, seq_len_kv, num_heads, head_dim]
v (torch.Tensor): 值张量,形状 [batch_size, seq_len_kv, num_heads, head_dim]
mode (str): 注意力模式,可选 'flash', 'torch', 'vanilla'
drop_rate (float): 注意力矩阵的dropout概率
attn_mask (torch.Tensor): 注意力掩码,形状根据模式不同而变化
causal (bool): 是否使用因果注意力(仅关注前面位置)
Returns:
torch.Tensor: 注意力输出,形状 [batch_size, seq_len, num_heads * head_dim]
"""
# look up the pre- and post-processing functions for this mode
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
# apply the pre-processing transform
q = pre_attn_layout(q) # shape depends on the mode
k = pre_attn_layout(k)
v = pre_attn_layout(v)
if mode == "torch":
# use PyTorch's native scaled_dot_product_attention
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
)
elif mode == "flash":
assert flash_attn_func is not None, "flash_attn_func is not defined (flash_attn is not installed)"
assert attn_mask is None, "attention masks are not supported in flash mode"
x: torch.Tensor = flash_attn_func(
q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
) # type: ignore
elif mode == "vanilla":
# manual (vanilla) attention implementation
scale_factor = 1 / math.sqrt(q.size(-1)) # scaling factor 1/sqrt(d_k)
b, a, s, _ = q.shape # unpack shape parameters
s1 = k.size(2) # key/value sequence length
# initialize the attention bias
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
# handle the causal mask
if causal:
assert attn_mask is None, "a causal mask and an explicit attention mask cannot be used together"
# build a lower-triangular causal mask
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
diagonal=0
)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(q.dtype)
# handle a user-provided attention mask
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask # allows ALiBi-style positional biases
# compute the attention matrix
attn = (q @ k.transpose(-2, -1)) * scale_factor # [B,A,S,S1]
attn += attn_bias
# softmax and dropout
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
# compute the output
x = attn @ v # [B,A,S,D]
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
# apply the post-processing transform
x = post_attn_layout(x) # restore the original dimension order
# merge the attention head dimension
b, s, a, d = x.shape
out = x.reshape(b, s, -1) # [B,S,A*D]
return out
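# Minimal usage sketch (not part of the original module): calls attention() in "torch"
# mode with random tensors laid out as described in the docstring above.
def _attention_example():
    b, s, heads, head_dim = 2, 16, 4, 32
    q = torch.randn(b, s, heads, head_dim)
    k = torch.randn(b, s, heads, head_dim)
    v = torch.randn(b, s, heads, head_dim)
    out = attention(q, k, v, mode="torch")
    # heads are merged back into a single feature dimension
    assert out.shape == (b, s, heads * head_dim)
    return out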
# Modified from Flux
#
# Copyright 2024 Black Forest Labs
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from einops import rearrange
from torch import Tensor, nn
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1, *tuple(ch_mult))
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
else:
return mean
class AutoEncoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
scale_factor: float,
shift_factor: float,
):
super().__init__()
self.encoder = Encoder(
resolution=resolution,
in_channels=in_channels,
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
z_channels=z_channels,
)
self.decoder = Decoder(
resolution=resolution,
in_channels=in_channels,
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
z_channels=z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = scale_factor
self.shift_factor = shift_factor
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))
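# Minimal roundtrip sketch (illustrative, not part of the original file): instantiates the
# autoencoder with the same configuration that load_models() passes in the inference script
# above and encodes/decodes a random batch.
def _autoencoder_example():
    ae = AutoEncoder(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )
    x = torch.randn(1, 3, 256, 256)  # images are expected in [-1, 1]
    z = ae.encode(x)                 # (1, 16, 32, 32): three 2x downsampling stages
    rec = ae.decode(z)               # (1, 3, 256, 256)
    return rec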
import torch
from qwen_vl_utils import process_vision_info
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
from torchvision.transforms import ToPILImage
to_pil = ToPILImage()
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''
def split_string(s):
# replace CJK quotes with ASCII double quotes
s = s.replace("“", '"').replace("”", '"') # use english quotes
result = []
# track whether we are currently inside a quoted span
in_quotes = False
temp = ""
# iterate over every character and its index
for idx, char in enumerate(s):
# if the character is a double quote and its index is greater than 155
if char == '"' and idx > 155:
# append the quote to the temporary buffer
temp += char
# if we are not inside a quoted span
if not in_quotes:
# flush the buffer to the result list
result.append(temp)
# clear the buffer
temp = ""
# toggle the quote state
in_quotes = not in_quotes
continue
# if we are inside a quoted span
if in_quotes:
# spaces inside quotes are kept as well
if char.isspace():
pass # have space token
# wrap the character in CJK quotes and append it to the result list
result.append("“" + char + "”")
else:
# append the character to the buffer
temp += char
# if the buffer is not empty
if temp:
# flush the remaining buffer to the result list
result.append(temp)
return result
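# Usage sketch (not part of the original file): past character index 155, split_string
# keeps text outside double quotes intact and re-wraps every quoted character in CJK
# quotes so that it is later tokenized character by character. Assuming the first quote
# appears after index 155, a call such as
#   split_string('<...long prefix...> replace "ab" with "cd"')
# returns
#   ['<...long prefix...> replace "', '“a”', '“b”', '" with "', '“c”', '“d”', '"']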
class Qwen25VL_7b_Embedder(torch.nn.Module):
def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
super(Qwen25VL_7b_Embedder, self).__init__()
self.max_length = max_length
self.dtype = dtype
self.device = device
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=dtype,
attn_implementation="flash_attention_2",
).to(torch.cuda.current_device())
self.model.requires_grad_(False)
self.processor = AutoProcessor.from_pretrained(
model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
)
self.prefix = Qwen25VL_7b_PREFIX
def forward(self, caption, ref_images):
text_list = caption
embs = torch.zeros(
len(text_list),
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
)
hidden_states = torch.zeros(
len(text_list),
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
)
masks = torch.zeros(
len(text_list),
self.max_length,
dtype=torch.long,
device=torch.cuda.current_device(),
)
input_ids_list = []
attention_mask_list = []
emb_list = []
def split_string(s):
s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
result = []
in_quotes = False
temp = ""
for idx,char in enumerate(s):
if char == '"' and idx>155:
temp += char
if not in_quotes:
result.append(temp)
temp = ""
in_quotes = not in_quotes
continue
if in_quotes:
if char.isspace():
pass # have space token
result.append("“" + char + "”")
else:
temp += char
if temp:
result.append(temp)
return result
for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
messages = [{"role": "user", "content": []}]
messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})
# then append the user text
messages[0]["content"].append({"type": "text", "text": f"{txt}"})
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
padding=True,
return_tensors="pt",
)
old_inputs_ids = inputs.input_ids
text_split_list = split_string(text)
token_list = []
for text_each in text_split_list:
txt_inputs = self.processor(
text=text_each,
images=None,
videos=None,
padding=True,
return_tensors="pt",
)
token_each = txt_inputs.input_ids
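# Token ids 2073 and 854 appear to correspond to the CJK opening/closing quotes that
# split_string wraps around single characters (an assumption based on this check);
# stripping them keeps only the token of the quoted character itself.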
if token_each[0][0] == 2073 and token_each[0][-1] == 854:
token_each = token_each[:, 1:-1]
token_list.append(token_each)
else:
token_list.append(token_each)
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
inputs.input_ids = (
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
.unsqueeze(0)
.to("cuda")
)
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
outputs = self.model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
pixel_values=inputs.pixel_values.to("cuda"),
image_grid_thw=inputs.image_grid_thw.to("cuda"),
output_hidden_states=True,
)
emb = outputs["hidden_states"][-1]
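# The first 217 positions of the hidden states are skipped below; they appear to cover
# the fixed instruction prefix and the image tokens, so only the embedding of the user
# prompt (truncated to max_length) is kept (assumption based on the slicing that follows).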
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
: self.max_length
]
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
(min(self.max_length, emb.shape[1] - 217)),
dtype=torch.long,
device=torch.cuda.current_device(),
)
return embs, masks
from typing import Optional
import torch
import torch.nn
from einops import rearrange
from torch import nn
from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
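# In short, for an input x the layer computes
#     y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight
# where the weight factor is only applied when elementwise_affine=True.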
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
def get_activation_layer(act_type):
"""get activation layer
Args:
act_type (str): the activation type
Returns:
torch.nn.functional: the activation layer
"""
if act_type == "gelu":
return lambda: nn.GELU()
elif act_type == "gelu_tanh":
return lambda: nn.GELU(approximate="tanh")
elif act_type == "relu":
return nn.ReLU
elif act_type == "silu":
return nn.SiLU
else:
raise ValueError(f"Unknown activation type: {act_type}")
class IndividualTokenRefinerBlock(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.need_CA = need_CA
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.self_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.self_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
self.mlp = MLP(
in_channels=hidden_size,
hidden_channels=mlp_hidden_dim,
act_layer=act_layer,
drop=mlp_drop_rate,
**factory_kwargs,
)
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
if self.need_CA:
self.cross_attnblock = CrossAttnBlock(
hidden_size=hidden_size,
heads_num=heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
attn_mask: torch.Tensor = None,
y: torch.Tensor = None,
):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
# Self-Attention
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
if self.need_CA:
x = self.cross_attnblock(x, c, attn_mask, y)
# FFN Layer
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
return x
class CrossAttnBlock(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.norm1_2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.self_attn_q = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.self_attn_kv = nn.Linear(
hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.self_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
attn_mask: torch.Tensor = None,
y: torch.Tensor=None,
):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
norm_y = self.norm1_2(y)
q = self.self_attn_q(norm_x)
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
kv = self.self_attn_kv(norm_y)
k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
# Apply QK-Norm if needed
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
# Self-Attention
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
return x
class IndividualTokenRefiner(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
depth,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA:bool=False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.need_CA = need_CA
self.blocks = nn.ModuleList(
[
IndividualTokenRefinerBlock(
hidden_size=hidden_size,
heads_num=heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
need_CA=self.need_CA,
**factory_kwargs,
)
for _ in range(depth)
]
)
def forward(
self,
x: torch.Tensor,
c: torch.LongTensor,
mask: Optional[torch.Tensor] = None,
y:torch.Tensor=None,
):
self_attn_mask = None
if mask is not None:
batch_size = mask.shape[0]
seq_len = mask.shape[1]
mask = mask.to(x.device)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
1, 1, seq_len, 1
)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
# avoids self-attention weight being NaN for padding tokens
self_attn_mask[:, :, :, 0] = True
for block in self.blocks:
x = block(x, c, self_attn_mask,y)
return x
class SingleTokenRefiner(torch.nn.Module):
"""
A single token refiner block for llm text embedding refine.
"""
def __init__(
self,
in_channels,
hidden_size,
heads_num,
depth,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA:bool=False,
attn_mode: str = "torch",
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.attn_mode = attn_mode
self.need_CA = need_CA
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
self.input_embedder = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)
if self.need_CA:
self.input_embedder_CA = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
# Build timestep embedding layer
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
# Build context embedding layer
self.c_embedder = TextProjection(
in_channels, hidden_size, act_layer, **factory_kwargs
)
self.individual_token_refiner = IndividualTokenRefiner(
hidden_size=hidden_size,
heads_num=heads_num,
depth=depth,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
need_CA=need_CA,
**factory_kwargs,
)
def forward(
self,
x: torch.Tensor,
t: torch.LongTensor,
mask: Optional[torch.LongTensor] = None,
y: torch.LongTensor=None,
):
timestep_aware_representations = self.t_embedder(t)
if mask is None:
context_aware_representations = x.mean(dim=1)
else:
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
context_aware_representations = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
x = self.input_embedder(x)
if self.need_CA:
y = self.input_embedder_CA(y)
x = self.individual_token_refiner(x, c, mask, y)
else:
x = self.individual_token_refiner(x, c, mask)
return x
class Qwen2Connector(torch.nn.Module):
def __init__(
self,
# biclip_dim=1024,
in_channels=3584,
hidden_size=4096,
heads_num=32,
depth=2,
need_CA=False,
device=None,
dtype=torch.bfloat16,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.S = SingleTokenRefiner(
in_channels=in_channels,
hidden_size=hidden_size,
heads_num=heads_num,
depth=depth,
need_CA=need_CA,
**factory_kwargs,
)
self.global_proj_out = nn.Linear(in_channels, 768)
self.scale_factor = nn.Parameter(torch.zeros(1))
with torch.no_grad():
self.scale_factor.data += -(1 - 0.09)
def forward(self, x, t, mask):
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
x_mean = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(x.dtype))
global_out = self.global_proj_out(x_mean)
encoder_hidden_states = self.S(x, t, mask)
return encoder_hidden_states, global_out
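# Usage note (shapes stated as an assumption inferred from the defaults above): given
# Qwen2.5-VL last hidden states x of shape (B, S, 3584), timesteps t of shape (B,) and a
# 0/1 mask of shape (B, S), the connector returns refined per-token states of shape
# (B, S, 4096) together with a pooled 768-d global vector, which the inference script uses
# as txt and vec in `txt, vec = self.dit.connector(llm_embedding, t_vec, mask)`.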