Commit 40381d5a authored by helloyongyang

support cache with changing res

parent 20525490
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"changing_resolution": true,
"resolution_rate": [1.0, 0.75],
"changing_resolution_steps": [10, 35],
"feature_caching": "Tea",
"coefficients": [
[-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02],
[2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
],
"use_ret_steps": false,
"teacache_thresh": 0.1
}
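In this config, changing_resolution_steps marks the denoising steps (10 and 35 out of 50) at which the resolution switches, and resolution_rate gives the scale factors involved; those same two steps are the ones the caching code below is forced to recompute. A small sketch of the implied numbers follows (the variable names and the exact rounding/alignment are illustrative assumptions, not taken from the repo):

# Illustration of the new config keys; names below are placeholders, not repo code.
target_height, target_width = 480, 832
resolution_rate = [1.0, 0.75]
changing_resolution_steps = [10, 35]
infer_steps = 50

# Resolutions implied by the scale factors (the model may round or snap differently):
# 1.0  -> 480 x 832, 0.75 -> 360 x 624
scaled = [(round(target_height * r), round(target_width * r)) for r in resolution_rate]

# Steps at which cached transformer features must be recomputed (see must_calc below).
force_recompute = [step in changing_resolution_steps for step in range(infer_steps)]
assert [s for s, forced in enumerate(force_recompute) if forced] == [10, 35]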
@@ -5,7 +5,20 @@ import numpy as np
 import gc


-class WanTransformerInferTeaCaching(WanTransformerInfer):
+class WanTransformerInferCaching(WanTransformerInfer):
+    def __init__(self, config):
+        super().__init__(config)
+        self.must_calc_steps = []
+        if self.config.get("changing_resolution", False):
+            self.must_calc_steps = self.config["changing_resolution_steps"]
+
+    def must_calc(self, step_index):
+        if step_index in self.must_calc_steps:
+            return True
+        return False
+
+
+class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
@@ -87,7 +100,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         should_calc = self.calculate_should_calc(embed, embed0)
         self.scheduler.caching_records[index] = should_calc
-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -99,7 +112,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         should_calc = self.calculate_should_calc(embed, embed0)
         self.scheduler.caching_records_2[index] = should_calc
-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -169,7 +182,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
         torch.cuda.empty_cache()


-class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
+class WanTransformerInferTaylorCaching(WanTransformerInferCaching, BaseTaylorCachingTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
@@ -199,7 +212,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         index = self.scheduler.step_index
         caching_records = self.scheduler.caching_records
-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
@@ -208,7 +221,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         index = self.scheduler.step_index
         caching_records_2 = self.scheduler.caching_records_2
-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
@@ -305,7 +318,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInfer, BaseTaylorCachingTra
         torch.cuda.empty_cache()


-class WanTransformerInferAdaCaching(WanTransformerInfer):
+class WanTransformerInferAdaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
@@ -322,7 +335,7 @@ class WanTransformerInferAdaCaching(WanTransformerInfer):
         index = self.scheduler.step_index
         caching_records = self.scheduler.caching_records
-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             # 1. calculate the skipped step length
@@ -338,7 +351,7 @@ class WanTransformerInferAdaCaching(WanTransformerInfer):
         index = self.scheduler.step_index
         caching_records = self.scheduler.caching_records_2
-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             # 1. calculate the skipped step length
@@ -518,7 +531,7 @@ class AdaArgs:
         self.spatial_dim = 1536


-class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
+class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCachingTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
@@ -605,7 +618,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
         should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         self.scheduler.caching_records[index] = should_calc
-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -617,7 +630,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
         should_calc = self.calculate_should_calc(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         self.scheduler.caching_records_2[index] = should_calc
-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -683,7 +696,7 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
         torch.cuda.empty_cache()


-class WanTransformerInferFirstBlock(WanTransformerInfer):
+class WanTransformerInferFirstBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
@@ -707,7 +720,7 @@ class WanTransformerInferFirstBlock(WanTransformerInfer):
         should_calc = self.calculate_should_calc(x_residual)
         self.scheduler.caching_records[index] = should_calc
-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -719,7 +732,7 @@ class WanTransformerInferFirstBlock(WanTransformerInfer):
         should_calc = self.calculate_should_calc(x_residual)
         self.scheduler.caching_records_2[index] = should_calc
-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -788,7 +801,7 @@ class WanTransformerInferFirstBlock(WanTransformerInfer):
         torch.cuda.empty_cache()


-class WanTransformerInferDualBlock(WanTransformerInfer):
+class WanTransformerInferDualBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
@@ -822,7 +835,7 @@ class WanTransformerInferDualBlock(WanTransformerInfer):
         should_calc = self.calculate_should_calc(x_residual)
         self.scheduler.caching_records[index] = should_calc
-        if caching_records[index]:
+        if caching_records[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -834,7 +847,7 @@ class WanTransformerInferDualBlock(WanTransformerInfer):
         should_calc = self.calculate_should_calc(x_residual)
         self.scheduler.caching_records_2[index] = should_calc
-        if caching_records_2[index]:
+        if caching_records_2[index] or self.must_calc(index):
             x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
         else:
             x = self.infer_using_cache(x)
@@ -915,7 +928,7 @@ class WanTransformerInferDualBlock(WanTransformerInfer):
         torch.cuda.empty_cache()


-class WanTransformerInferDynamicBlock(WanTransformerInfer):
+class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
         self.residual_diff_threshold = config.residual_diff_threshold
@@ -938,7 +951,7 @@ class WanTransformerInferDynamicBlock(WanTransformerInfer):
         if self.infer_conditional:
             if self.block_in_cache_even[block_idx] is not None:
                 should_calc = self.are_two_tensor_similar(self.block_in_cache_even[block_idx], x)
-            if should_calc:
+            if should_calc or self.must_calc(block_idx):
                 x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             else:
                 x += self.block_residual_cache_even[block_idx]
@@ -953,7 +966,7 @@ class WanTransformerInferDynamicBlock(WanTransformerInfer):
         else:
             if self.block_in_cache_odd[block_idx] is not None:
                 should_calc = self.are_two_tensor_similar(self.block_in_cache_odd[block_idx], x)
-            if should_calc:
+            if should_calc or self.must_calc(block_idx):
                 x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             else:
                 x += self.block_residual_cache_odd[block_idx]
......
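Every caching variant in the file above now consults the same must_calc() override before trusting its cache, presumably because features cached at one resolution cannot be reused once the latent resolution changes. A repo-independent sketch of that decision logic (the class and helper names here are placeholders; only must_calc_steps mirrors the diff):

# Stand-alone sketch of the cache-override pattern added above; not lightx2v code.
class CachingStep:
    def __init__(self, changing_resolution_steps):
        # Steps where the resolution changes and recomputation is forced.
        self.must_calc_steps = list(changing_resolution_steps)

    def must_calc(self, step_index):
        return step_index in self.must_calc_steps

    def run_step(self, step_index, cache_says_recompute, compute_fn, reuse_fn):
        # Mirrors `if caching_records[index] or self.must_calc(index):` in the diff.
        if cache_says_recompute or self.must_calc(step_index):
            return compute_fn()
        return reuse_fn()

stepper = CachingStep(changing_resolution_steps=[10, 35])
print(stepper.must_calc(10), stepper.must_calc(11))  # True False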
@@ -8,16 +8,11 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
-    WanScheduler4ChangingResolution,
+    WanScheduler4ChangingResolutionInterface,
 )
 from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
-    WanSchedulerTeaCaching,
+    WanSchedulerCaching,
     WanSchedulerTaylorCaching,
-    WanSchedulerAdaCaching,
-    WanSchedulerCustomCaching,
-    WanSchedulerFirstBlock,
-    WanSchedulerDualBlock,
-    WanSchedulerDynamicBlock,
 )
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.utils import *
@@ -159,27 +154,19 @@ class WanRunner(DefaultRunner):
         return vae_encoder, vae_decoder

     def init_scheduler(self):
+        if self.config.feature_caching == "NoCaching":
+            scheduler_class = WanScheduler
+        elif self.config.feature_caching == "TaylorSeer":
+            scheduler_class = WanSchedulerTaylorCaching
+        elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]:
+            scheduler_class = WanSchedulerCaching
+        else:
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         if self.config.get("changing_resolution", False):
-            scheduler = WanScheduler4ChangingResolution(self.config)
+            scheduler = WanScheduler4ChangingResolutionInterface(scheduler_class, self.config)
         else:
-            if self.config.feature_caching == "NoCaching":
-                scheduler = WanScheduler(self.config)
-            elif self.config.feature_caching == "Tea":
-                scheduler = WanSchedulerTeaCaching(self.config)
-            elif self.config.feature_caching == "TaylorSeer":
-                scheduler = WanSchedulerTaylorCaching(self.config)
-            elif self.config.feature_caching == "Ada":
-                scheduler = WanSchedulerAdaCaching(self.config)
-            elif self.config.feature_caching == "Custom":
-                scheduler = WanSchedulerCustomCaching(self.config)
-            elif self.config.feature_caching == "FirstBlock":
-                scheduler = WanSchedulerFirstBlock(self.config)
-            elif self.config.feature_caching == "DualBlock":
-                scheduler = WanSchedulerDualBlock(self.config)
-            elif self.config.feature_caching == "DynamicBlock":
-                scheduler = WanSchedulerDynamicBlock(self.config)
-            else:
-                raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+            scheduler = scheduler_class(self.config)
         self.model.set_scheduler(scheduler)

     def run_text_encoder(self, text, img):
......
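The refactored init_scheduler separates "which scheduler class" from "wrap it for changing resolution or not". Purely as an illustration (this dict is not in the repo; only the class names and import paths come from the diff above), the same selection could be written as a table-driven lookup:

from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.changing_resolution.scheduler import WanScheduler4ChangingResolutionInterface
from lightx2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerCaching, WanSchedulerTaylorCaching

# Hypothetical table-driven equivalent of the new if/elif chain.
_SCHEDULER_CLASSES = {
    "NoCaching": WanScheduler,
    "TaylorSeer": WanSchedulerTaylorCaching,
    # Tea / Ada / Custom / FirstBlock / DualBlock / DynamicBlock now share one scheduler.
    **{name: WanSchedulerCaching for name in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock"]},
}

def pick_scheduler(config):
    try:
        scheduler_class = _SCHEDULER_CLASSES[config.feature_caching]
    except KeyError:
        raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")
    if config.get("changing_resolution", False):
        return WanScheduler4ChangingResolutionInterface(scheduler_class, config)
    return scheduler_class(config)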
@@ -2,9 +2,18 @@ import torch
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler


-class WanScheduler4ChangingResolution(WanScheduler):
+class WanScheduler4ChangingResolutionInterface:
+    def __new__(cls, father_scheduler, config):
+        class NewClass(WanScheduler4ChangingResolution, father_scheduler):
+            def __init__(self, config):
+                father_scheduler.__init__(self, config)
+                WanScheduler4ChangingResolution.__init__(self, config)
+
+        return NewClass(config)
+
+
+class WanScheduler4ChangingResolution:
     def __init__(self, config):
-        super().__init__(config)
         if "resolution_rate" not in config:
             config["resolution_rate"] = [0.75]
         if "changing_resolution_steps" not in config:
......
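WanScheduler4ChangingResolutionInterface uses __new__ as a small factory: it builds a class on the fly that mixes the changing-resolution behaviour into whatever scheduler class the runner passes in, then returns an instance of that combined class. A self-contained sketch of the same pattern with generic stand-in names (none of them lightx2v classes):

# Generic illustration of the factory-via-__new__ mixin pattern used above.
class BaseScheduler:
    def __init__(self, config):
        self.config = config

class ResolutionMixin:
    def __init__(self, config):
        self.resolution_rate = config.get("resolution_rate", [0.75])

class MixedSchedulerFactory:
    def __new__(cls, base_scheduler_cls, config):
        # Build a combined class and return an instance of it, not of the factory,
        # so the factory's own __init__ is never invoked.
        class _Mixed(ResolutionMixin, base_scheduler_cls):
            def __init__(self, config):
                base_scheduler_cls.__init__(self, config)
                ResolutionMixin.__init__(self, config)
        return _Mixed(config)

scheduler = MixedSchedulerFactory(BaseScheduler, {"resolution_rate": [1.0, 0.75]})
print(isinstance(scheduler, BaseScheduler), isinstance(scheduler, ResolutionMixin))  # True True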
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler


-class WanSchedulerTeaCaching(WanScheduler):
+class WanSchedulerCaching(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
@@ -9,53 +9,10 @@ class WanSchedulerTeaCaching(WanScheduler):
         self.transformer_infer.clear()


-class WanSchedulerTaylorCaching(WanScheduler):
+class WanSchedulerTaylorCaching(WanSchedulerCaching):
     def __init__(self, config):
         super().__init__(config)
         pattern = [True, False, False, False]
         self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
         self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]

-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerAdaCaching(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerCustomCaching(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerFirstBlock(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerDualBlock(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()
-
-
-class WanSchedulerDynamicBlock(WanScheduler):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def clear(self):
-        self.transformer_infer.clear()