Commit 701075f4 authored by Yang Yong(雍洋), committed by GitHub

refactor compiler (#301)

parent 60c421f4
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 360,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": false,
"compile": true,
"compile_shapes": [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
}
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 360,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": false,
"parallel": {
"seq_p_size": 8,
"seq_p_attn_type": "ulysses"
},
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
},
"adapter_quantized": true,
"adapter_quant_scheme": "fp8",
"t5_quantized": true,
"t5_quant_scheme": "fp8",
"compile": true,
"compile_shapes": [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
}
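The "compile" and "compile_shapes" keys in these configs drive the new ahead-of-time compilation path: every (height, width) resolution the service is expected to serve must be listed in "compile_shapes" so a graph can be captured for it before the first request. A minimal sketch of how such a config could be read and validated (the helper below is illustrative, not lightx2v's actual config loader):

import json

def load_and_check_compile_config(path, tgt_h, tgt_w):
    # Illustrative helper: read a config like the ones above and verify that
    # the requested resolution has a precompiled shape before serving it.
    with open(path) as f:
        cfg = json.load(f)
    if cfg.get("compile", False):
        shapes = [tuple(s) for s in cfg.get("compile_shapes", [])]
        if (tgt_h, tgt_w) not in shapes:
            raise ValueError(
                f"resolution {tgt_h}x{tgt_w} is not in compile_shapes; "
                "add it to the config or disable compile"
            )
    return cfg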
......@@ -33,7 +33,6 @@ class FlashAttn2Weight(AttnWeightTemplate):
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
x = flash_attn_varlen_func(
q,
......@@ -62,7 +61,6 @@ class FlashAttn3Weight(AttnWeightTemplate):
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
x = flash_attn_varlen_func_v3(
q,
......
......@@ -34,7 +34,6 @@ class SageAttn2Weight(AttnWeightTemplate):
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if model_cls == "hunyuan":
......
......@@ -24,7 +24,6 @@ class TorchSDPAWeight(AttnWeightTemplate):
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
mask_map=None,
):
if q.ndim == 3:
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
......
......@@ -86,15 +86,19 @@ class UlyssesAttnWeight(AttnWeightTemplate):
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
# process the image attention results
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims) # reshape the image attention result
img_attn = all2all_head2seq(img_attn, group=seq_p_group) # convert the head-sharded layout back to the sequence-sharded layout
img_attn = img_attn.reshape(shard_seqlen, -1) # reshape to [shard_seqlen, -1]
img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group)
torch.cuda.synchronize() # make sure the CUDA ops have completed
txt_attn = torch.cat(gathered_txt_attn, dim=1) # concatenate the text attention results from all ranks
# merge the image and text attention results
attn = torch.cat([img_attn, txt_attn], dim=0)
return attn # return the final attention result
@torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims) # reshape the image attention result
img_attn = all2all_head2seq(img_attn, group=seq_p_group) # convert the head-sharded layout back to the sequence-sharded layout
img_attn = img_attn.reshape(shard_seqlen, -1) # reshape to [shard_seqlen, -1]
torch.cuda.synchronize() # make sure the CUDA ops have completed
return img_attn
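Moving the all-to-all and the explicit synchronize into a helper marked @torch.compiler.disable keeps the distributed communication and the CUDA sync out of the compiled graph, so torch.compile only traces the pure tensor math around it. A minimal, self-contained illustration of the same pattern (generic module and method names, not the Ulysses code itself):

import torch

class ShardedAttention(torch.nn.Module):
    def forward(self, x):
        x = x * 2                     # traced and compiled by torch.compile
        x = self._gather_and_sync(x)  # excluded from the compiled graph
        return x + 1                  # compilation resumes after the call

    @torch.compiler.disable
    def _gather_and_sync(self, x):
        # Stand-in for the all2all_head2seq + synchronize moved into
        # _reshape_img_attn above: collectives and explicit CUDA syncs are
        # fenced off so the compiler never has to trace them.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return x

mod = torch.compile(ShardedAttention())
out = mod(torch.ones(4))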
......@@ -14,8 +14,6 @@ from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.infer import init_runner # noqa
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.utils.envs import CHECK_ENABLE_GRAPH_MODE
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import set_config, set_parallel_config
......@@ -189,12 +187,7 @@ class PipelineWorker(BaseWorker):
def __init__(self, args):
super().__init__(args)
self.runner.init_modules()
if CHECK_ENABLE_GRAPH_MODE():
self.init_temp_params()
self.graph_runner = GraphRunner(self.runner)
self.run_func = self.graph_runner.run_pipeline
else:
self.run_func = self.runner.run_pipeline
self.run_func = self.runner.run_pipeline
def init_temp_params(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
......
import argparse
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.common.ops import *
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner # noqa: F401
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner # noqa: F401
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401
......@@ -23,14 +23,9 @@ from lightx2v.utils.utils import seed_all
def init_runner(config):
seed_all(config.seed)
if CHECK_ENABLE_GRAPH_MODE():
default_runner = RUNNER_REGISTER[config.model_cls](config)
default_runner.init_modules()
runner = GraphRunner(default_runner)
else:
runner = RUNNER_REGISTER[config.model_cls](config)
runner.init_modules()
torch.set_grad_enabled(False)
runner = RUNNER_REGISTER[config.model_cls](config)
runner.init_modules()
return runner
......
......@@ -29,7 +29,6 @@ class HunyuanTransformerInfer(BaseTransformerInfer):
else:
self.infer_func = self._infer_without_offload
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
......
......@@ -24,7 +24,6 @@ class QwenImagePreInfer:
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, hidden_states, timestep, guidance, encoder_hidden_states_mask, encoder_hidden_states, img_shapes, txt_seq_lens, attention_kwargs):
hidden_states_0 = hidden_states
hidden_states = self.img_in(hidden_states)
......
import os
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
......@@ -46,3 +48,69 @@ class WanAudioModel(WanModel):
self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer
self.transformer_infer_class = WanAudioTransformerInfer
def get_graph_name(self, shape):
return f"graph_{shape[0]}x{shape[1]}"
def start_compile(self, shape):
graph_name = self.get_graph_name(shape)
logger.info(f"[Compile] Compile shape: {shape}, graph_name: {graph_name}")
target_video_length = self.config.get("target_video_length", 81)
latents_length = (target_video_length - 1) // 16 * 4 + 1
latents_h = shape[0] // self.config.vae_stride[1]
latents_w = shape[1] // self.config.vae_stride[2]
new_inputs = {}
new_inputs["text_encoder_output"] = {}
new_inputs["text_encoder_output"]["context"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda()
new_inputs["text_encoder_output"]["context_null"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda()
new_inputs["image_encoder_output"] = {}
new_inputs["image_encoder_output"]["clip_encoder_out"] = torch.randn(257, 1280, dtype=torch.bfloat16).cuda()
new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()
new_inputs["audio_encoder_output"] = torch.randn(1, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
new_inputs["previmg_encoder_output"] = {}
new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
new_inputs["previmg_encoder_output"]["prev_mask"] = torch.randn(4, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
self.scheduler.latents = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
self.scheduler.timestep_input = torch.tensor([600.0], dtype=torch.float32).cuda()
self.scheduler.audio_adapter_t_emb = torch.randn(1, 3, 5120, dtype=torch.bfloat16).cuda()
self._infer_cond_uncond(new_inputs, infer_condition=True, graph_name=graph_name)
def compile(self, compile_shapes):
self.check_compile_shapes(compile_shapes)
self.enable_compile_mode("_infer_cond_uncond")
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.transformer_weights.non_block_weights_to_cuda()
for shape in compile_shapes:
self.start_compile(shape)
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.transformer_weights.non_block_weights_to_cpu()
self.disable_compile_mode("_infer_cond_uncond")
logger.info(f"[Compile] Compile status: {self.get_compile_status()}")
def check_compile_shapes(self, compile_shapes):
for shape in compile_shapes:
assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
def select_graph_for_compile(self):
logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}")
self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}")
logger.info(f"[Compile] Compile status: {self.get_compile_status()}")
......@@ -2,7 +2,6 @@ import os
import torch
from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
WanTransformerInferCausVid,
)
......@@ -45,11 +44,6 @@ class WanCausVidModel(WanModel):
@torch.no_grad()
def infer(self, inputs, kv_start, kv_end):
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
video_token_num = c * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, c)
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.transformer_weights.post_weights_to_cuda()
......
......@@ -8,7 +8,7 @@ class WanAudioPostInfer(WanPostInfer):
def __init__(self, config):
super().__init__(config)
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
@torch.no_grad()
def infer(self, x, pre_infer_out):
x = x[: pre_infer_out.seq_lens[0]]
......
......@@ -23,29 +23,30 @@ class WanAudioPreInfer(WanPreInfer):
).cuda()
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.text_len = config["text_len"]
self.rope_t_dim = d // 2 - 2 * (d // 6)
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
@torch.no_grad()
def infer(self, weights, inputs):
infer_condition, latents, timestep_input = self.scheduler.infer_condition, self.scheduler.latents, self.scheduler.timestep_input
prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
hidden_states = self.scheduler.latents
hidden_states = latents
if self.config.model_cls != "wan2.2_audio":
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)
x = hidden_states
t = self.scheduler.timestep_input
t = timestep_input
if self.scheduler.infer_condition:
if infer_condition:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(latents.dtype)
num_channels, _, height, width = x.shape
ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
......@@ -53,15 +54,15 @@ class WanAudioPreInfer(WanPreInfer):
if ref_num_channels != num_channels:
zero_padding = torch.zeros(
(num_channels - ref_num_channels, ref_num_frames, height, width),
dtype=self.scheduler.latents.dtype,
device=self.scheduler.latents.device,
dtype=latents.dtype,
device=latents.device,
)
ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=0)
y = ref_image_encoder
# embeddings
x = weights.patch_embedding.apply(x.unsqueeze(0))
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
x = x.flatten(2).transpose(1, 2).contiguous()
seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
......@@ -70,8 +71,8 @@ class WanAudioPreInfer(WanPreInfer):
x = torch.cat([x, y], dim=1).squeeze(0)
#### for r2v: zero the temporal component corresponding to the ref embeddings
self.freqs[grid_sizes[0][0] :, : self.rope_t_dim] = 0
grid_sizes[:, 0] += 1
self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
grid_sizes_t += 1
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
if self.sensitive_layer_dtype != self.infer_dtype:
......@@ -85,15 +86,14 @@ class WanAudioPreInfer(WanPreInfer):
embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
# text embeddings
stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
if self.sensitive_layer_dtype != self.infer_dtype:
out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype))
else:
out = weights.text_embedding_0.apply(stacked.squeeze(0))
out = weights.text_embedding_0.apply(context.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
context = weights.text_embedding_2.apply(out)
if self.clean_cuda_cache:
del out, stacked
del out
torch.cuda.empty_cache()
if self.task == "i2v" and self.config.get("use_image_encoder", True):
......@@ -114,7 +114,7 @@ class WanAudioPreInfer(WanPreInfer):
del context_clip
torch.cuda.empty_cache()
grid_sizes = GridOutput(tensor=grid_sizes, tuple=(grid_sizes[0][0].item(), grid_sizes[0][1].item(), grid_sizes[0][2].item()))
grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=grid_sizes,
......
......@@ -46,7 +46,6 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
self.crossattn_cache = crossattn_cache
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end)
......@@ -127,7 +126,6 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
mask_map=self.mask_map,
)
else:
# TODO: Implement parallel attention for causvid inference
......
......@@ -14,6 +14,7 @@ class WanPostInfer:
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.no_grad()
def infer(self, x, pre_infer_out):
x = self.unpatchify(x, pre_infer_out.grid_sizes.tuple)
......
......@@ -23,7 +23,6 @@ class WanPreInfer:
).cuda()
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.text_len = config["text_len"]
self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
self.cfg_scale = config.get("cfg_scale", 4.0)
self.infer_dtype = GET_DTYPE()
......@@ -32,6 +31,7 @@ class WanPreInfer:
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.no_grad()
def infer(self, weights, inputs, kv_start=0, kv_end=0):
x = self.scheduler.latents
t = self.scheduler.timestep_input
......@@ -61,7 +61,7 @@ class WanPreInfer:
# embeddings
x = weights.patch_embedding.apply(x.unsqueeze(0))
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
x = x.flatten(2).transpose(1, 2).contiguous()
seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
......@@ -84,15 +84,14 @@ class WanPreInfer:
embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
# text embeddings
stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
if self.sensitive_layer_dtype != self.infer_dtype:
out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype))
else:
out = weights.text_embedding_0.apply(stacked.squeeze(0))
out = weights.text_embedding_0.apply(context.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
context = weights.text_embedding_2.apply(out)
if self.clean_cuda_cache:
del out, stacked
del out
torch.cuda.empty_cache()
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
......@@ -117,7 +116,7 @@ class WanPreInfer:
del context_clip
torch.cuda.empty_cache()
grid_sizes = GridOutput(tensor=grid_sizes, tuple=(grid_sizes[0][0].item(), grid_sizes[0][1].item(), grid_sizes[0][2].item()))
grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=grid_sizes,
......
......@@ -26,7 +26,6 @@ class WanTransformerInfer(BaseTransformerInfer):
else:
self.apply_rotary_emb_func = apply_rotary_emb
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.mask_map = None
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
......@@ -49,6 +48,7 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
return freqs_i
@torch.no_grad()
def infer(self, weights, pre_infer_out):
x = self.infer_main_blocks(weights.blocks, pre_infer_out)
return self.infer_non_blocks(weights, x, pre_infer_out.embed)
......@@ -186,7 +186,6 @@ class WanTransformerInfer(BaseTransformerInfer):
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
mask_map=self.mask_map,
)
y = phase.self_attn_o.apply(attn_out)
......
......@@ -7,7 +7,6 @@ import torch.nn.functional as F
from loguru import logger
from safetensors import safe_open
from lightx2v.common.ops.attn import MaskMap
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
WanTransformerInferAdaCaching,
WanTransformerInferCustomCaching,
......@@ -30,6 +29,7 @@ from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.custom_compiler import CompiledMethodsMixin, compiled_method
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
......@@ -39,11 +39,12 @@ except ImportError:
gguf = None
class WanModel:
class WanModel(CompiledMethodsMixin):
pre_weight_class = WanPreWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device):
super().__init__()
self.model_path = model_path
self.config = config
self.cpu_offload = self.config.get("cpu_offload", False)
......@@ -340,11 +341,6 @@ class WanModel:
self.pre_weight.to_cuda()
self.transformer_weights.non_block_weights_to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
video_token_num = c * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, c)
if self.config["enable_cfg"]:
if self.config["cfg_parallel"]:
# ==================== CFG Parallel Processing ====================
......@@ -378,7 +374,7 @@ class WanModel:
self.pre_weight.to_cpu()
self.transformer_weights.non_block_weights_to_cpu()
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
@compiled_method()
@torch.no_grad()
def _infer_cond_uncond(self, inputs, infer_condition=True):
self.scheduler.infer_condition = infer_condition
......
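Replacing @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()) with @compiled_method() moves the compile switch from a global environment flag (and the separate GraphRunner wrapper removed elsewhere in this commit) to a per-method opt-in on the model itself. The snippet below is only a guess at how such a decorator/mixin pair could be wired up, not the actual lightx2v.utils.custom_compiler code:

import functools
import torch

def compiled_method():
    # Hypothetical sketch: mark a method as compilable and route calls to a
    # per-instance compiled version when one has been installed.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            compiled = getattr(self, f"_compiled_{fn.__name__}", None)
            target = compiled if compiled is not None else fn
            return target(self, *args, **kwargs)
        return wrapper
    return decorator

class CompiledMethodsMixin:
    # Hypothetical counterpart: enable_compile_mode swaps in a torch.compile'd
    # version of the named method; disable_compile_mode restores eager mode.
    def enable_compile_mode(self, name):
        fn = getattr(type(self), name).__wrapped__
        setattr(self, f"_compiled_{name}", torch.compile(fn, dynamic=False))

    def disable_compile_mode(self, name):
        setattr(self, f"_compiled_{name}", None)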
......@@ -31,8 +31,7 @@ class CogvideoxRunner(DefaultRunner):
return vae_model, vae_model
def init_scheduler(self):
scheduler = CogvideoxXDPMScheduler(self.config)
self.model.set_scheduler(scheduler)
self.scheduler = CogvideoxXDPMScheduler(self.config)
def run_text_encoder(self, text, img):
text_encoder_output = {}
......