Commit 20525490 authored by Rongjin Yang, committed by GitHub

04 cache dit (#165)

* read through BaseRunner -> DefaultRunner -> WanRunner's init

* FirstBlock

* FirstBlock

* FirstBlock

* FirstBlock

* FirstBlock

* FirstBlock

* FirstBlock

* DualBlock

* DynamicBlock

* DynamicBlock

* DynamicBlock

* DynamicBlock

* DynamicBlock

* DualBlock Downsample Factor

* DualBlock Downsample Factor

* Downsample

* Downsample
parent 38d11b82
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"feature_caching": "DualBlock",
"residual_diff_threshold": 0.03,
"downsample_factor": 2
}
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"feature_caching": "DynamicBlock",
"residual_diff_threshold": 0.003,
"downsample_factor": 2
}
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"feature_caching": "FirstBlock",
"residual_diff_threshold": 0.02,
"downsample_factor": 2
}
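The three JSON blocks above are example configs for the DualBlock, DynamicBlock, and FirstBlock caching modes; they differ only in feature_caching and residual_diff_threshold. residual_diff_threshold is the minimum relative change between consecutive gate residuals that forces a full recomputation, and downsample_factor strides the last dimension before the comparison to make the check cheaper. A minimal, self-contained sketch of that decision rule, mirroring calculate_should_calc in the new code below (tensor shapes and values here are illustrative only):

import torch

def should_recompute(prev_residual, curr_residual, threshold=0.03, downsample_factor=2):
    # Relative L1 change between the previous and current gate residuals,
    # computed on a strided (downsampled) view of the last dimension.
    if prev_residual is None:
        return True  # nothing cached yet, must compute
    t1 = prev_residual[..., ::downsample_factor]
    t2 = curr_residual[..., ::downsample_factor]
    rel_diff = (t1 - t2).abs().mean() / t1.abs().mean()
    return rel_diff.item() >= threshold

# Illustrative usage with toy tensors.
prev = torch.randn(1, 1024, 1536)
near = prev + 0.001 * torch.randn_like(prev)
print(should_recompute(prev, near))   # False: change is far below the 0.03 threshold
print(should_recompute(prev, -prev))  # True: relative difference is exactly 2.0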
@@ -5,7 +5,6 @@ import numpy as np
import gc
-# 1. TeaCaching
class WanTransformerInferTeaCaching(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
@@ -682,3 +681,306 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTra
self.previous_e0_odd = None
torch.cuda.empty_cache()
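# FirstBlock caching: block 0 runs every step; its residual gates whether blocks 1..N-1
# are recomputed or approximated by their cached aggregate residual. Separate even/odd
# caches serve the conditional and unconditional CFG passes.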
class WanTransformerInferFirstBlock(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.residual_diff_threshold = config.residual_diff_threshold
self.prev_first_block_residual_even = None
self.prev_remaining_blocks_residual_even = None
self.prev_first_block_residual_odd = None
self.prev_remaining_blocks_residual_odd = None
self.downsample_factor = self.config.downsample_factor
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
x = super().infer_block(weights.blocks[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
x_residual = x - ori_x
del ori_x
if self.infer_conditional:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(x_residual)
self.scheduler.caching_records[index] = should_calc
if caching_records[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(x)
else:
index = self.scheduler.step_index
caching_records_2 = self.scheduler.caching_records_2
if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(x_residual)
self.scheduler.caching_records_2[index] = should_calc
if caching_records_2[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(x)
if self.config.enable_cfg:
self.switch_status()
return x
def calculate_should_calc(self, x_residual):
diff = 1.0
x_residual_downsampled = x_residual[..., :: self.downsample_factor]
if self.infer_conditional:
if self.prev_first_block_residual_even is not None:
t1 = self.prev_first_block_residual_even
t2 = x_residual_downsampled
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = (mean_diff / mean_t1).item()
self.prev_first_block_residual_even = x_residual_downsampled
else:
if self.prev_first_block_residual_odd is not None:
t1 = self.prev_first_block_residual_odd
t2 = x_residual_downsampled
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = (mean_diff / mean_t1).item()
self.prev_first_block_residual_odd = x_residual_downsampled
return diff >= self.residual_diff_threshold
def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
for block_idx in range(1, self.blocks_num):
x = super().infer_block(
weights.blocks[block_idx],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
if self.infer_conditional:
self.prev_remaining_blocks_residual_even = x - ori_x
else:
self.prev_remaining_blocks_residual_odd = x - ori_x
del ori_x
return x
def infer_using_cache(self, x):
if self.infer_conditional:
return x.add_(self.prev_remaining_blocks_residual_even)
else:
return x.add_(self.prev_remaining_blocks_residual_odd)
def clear(self):
self.prev_first_block_residual_even = None
self.prev_remaining_blocks_residual_even = None
self.prev_first_block_residual_odd = None
self.prev_remaining_blocks_residual_odd = None
torch.cuda.empty_cache()
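# DualBlock caching: the first five and last five blocks are always computed; the residual
# of the first five gates whether the middle blocks are recomputed or replaced by their
# cached residual (again split into even/odd caches for the two CFG passes).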
class WanTransformerInferDualBlock(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.residual_diff_threshold = config.residual_diff_threshold
self.prev_front_blocks_residual_even = None
self.prev_middle_blocks_residual_even = None
self.prev_front_blocks_residual_odd = None
self.prev_middle_blocks_residual_odd = None
self.downsample_factor = self.config.downsample_factor
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
for block_idx in range(0, 5):
x = super().infer_block(
weights.blocks[block_idx],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
x_residual = x - ori_x
del ori_x
if self.infer_conditional:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(x_residual)
self.scheduler.caching_records[index] = should_calc
if caching_records[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(x)
else:
index = self.scheduler.step_index
caching_records_2 = self.scheduler.caching_records_2
if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(x_residual)
self.scheduler.caching_records_2[index] = should_calc
if caching_records_2[index]:
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x = self.infer_using_cache(x)
for block_idx in range(self.blocks_num - 5, self.blocks_num):
x = super().infer_block(
weights.blocks[block_idx],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
if self.config.enable_cfg:
self.switch_status()
return x
def calculate_should_calc(self, x_residual):
diff = 1.0
x_residual_downsampled = x_residual[..., :: self.downsample_factor]
if self.infer_conditional:
if self.prev_front_blocks_residual_even is not None:
t1 = self.prev_front_blocks_residual_even
t2 = x_residual_downsampled
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = (mean_diff / mean_t1).item()
self.prev_front_blocks_residual_even = x_residual_downsampled
else:
if self.prev_front_blocks_residual_odd is not None:
t1 = self.prev_front_blocks_residual_odd
t2 = x_residual_downsampled
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = (mean_diff / mean_t1).item()
self.prev_front_blocks_residual_odd = x_residual_downsampled
return diff >= self.residual_diff_threshold
def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
for block_idx in range(5, self.blocks_num - 5):
x = super().infer_block(
weights.blocks[block_idx],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
if self.infer_conditional:
self.prev_middle_blocks_residual_even = x - ori_x
else:
self.prev_middle_blocks_residual_odd = x - ori_x
del ori_x
return x
def infer_using_cache(self, x):
if self.infer_conditional:
return x.add_(self.prev_middle_blocks_residual_even)
else:
return x.add_(self.prev_middle_blocks_residual_odd)
def clear(self):
self.prev_front_blocks_residual_even = None
self.prev_middle_blocks_residual_even = None
self.prev_front_blocks_residual_odd = None
self.prev_middle_blocks_residual_odd = None
torch.cuda.empty_cache()
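# DynamicBlock caching: the decision is made per block by comparing the block's current
# input with the input cached on the first step of each CFG pass; are_two_tensor_similar
# returns True when the relative difference reaches the threshold, i.e. when the block
# must be recomputed rather than reusing its cached residual.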
class WanTransformerInferDynamicBlock(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.residual_diff_threshold = config.residual_diff_threshold
self.downsample_factor = self.config.downsample_factor
self.block_in_cache_even = {i: None for i in range(self.blocks_num)}
self.block_residual_cache_even = {i: None for i in range(self.blocks_num)}
self.block_in_cache_odd = {i: None for i in range(self.blocks_num)}
self.block_residual_cache_odd = {i: None for i in range(self.blocks_num)}
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num):
x = self.infer_block(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx)
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx):
ori_x = x.clone()
if self.infer_conditional:
if self.block_in_cache_even[block_idx] is not None:
should_calc = self.are_two_tensor_similar(self.block_in_cache_even[block_idx], x)
if should_calc:
x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x += self.block_residual_cache_even[block_idx]
else:
x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
self.block_in_cache_even[block_idx] = ori_x
self.block_residual_cache_even[block_idx] = x - ori_x
del ori_x
else:
if self.block_in_cache_odd[block_idx] is not None:
should_calc = self.are_two_tensor_similar(self.block_in_cache_odd[block_idx], x)
if should_calc:
x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
else:
x += self.block_residual_cache_odd[block_idx]
else:
x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
self.block_in_cache_odd[block_idx] = ori_x
self.block_residual_cache_odd[block_idx] = x - ori_x
del ori_x
return x
def are_two_tensor_similar(self, t1, t2):
diff = 1.0
t1_downsampled = t1[..., :: self.downsample_factor]
t2_downsampled = t2[..., :: self.downsample_factor]
mean_diff = (t1_downsampled - t2_downsampled).abs().mean()
mean_t1 = t1_downsampled.abs().mean()
diff = (mean_diff / mean_t1).item()
return diff >= self.residual_diff_threshold
def clear(self):
for i in range(self.blocks_num):
self.block_in_cache_even[i] = None
self.block_residual_cache_even[i] = None
self.block_in_cache_odd[i] = None
self.block_residual_cache_odd[i] = None
torch.cuda.empty_cache()
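All three strategies reuse work the same way: when the gate decides the hidden states have barely changed, the skipped blocks are approximated by adding the residual stored the last time they were fully computed (infer_using_cache, or the x += block_residual_cache branch in DynamicBlock). Below is a self-contained toy version of that skip-with-cached-residual pattern; the block stack and names are illustrative stand-ins, not part of lightx2v:

import torch
import torch.nn as nn

torch.manual_seed(0)
blocks = nn.ModuleList(nn.Linear(64, 64) for _ in range(8))  # toy stand-in for transformer blocks
cached_residual = None

def run_remaining_blocks(x, recompute):
    # Full path: run the blocks and cache their aggregate residual.
    # Cheap path: approximate their effect by re-adding the cached residual.
    global cached_residual
    if recompute or cached_residual is None:
        ori = x.clone()
        for blk in blocks:
            x = x + blk(x)
        cached_residual = x - ori
        return x
    return x + cached_residual

x = torch.randn(2, 64)
y_full = run_remaining_blocks(x, recompute=True)
y_cached = run_remaining_blocks(x, recompute=False)
print(torch.allclose(y_full, y_cached))  # True: on identical input the cached residual is exact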
@@ -12,12 +12,12 @@ from functools import partial
class WanTransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
-self.task = config["task"]
+self.task = config.task
self.attention_type = config.get("attention_type", "flash_attn2")
-self.blocks_num = config["num_layers"]
+self.blocks_num = config.num_layers
self.phases_num = 4
-self.num_heads = config["num_heads"]
+self.num_heads = config.num_heads
-self.head_dim = config["dim"] // config["num_heads"]
+self.head_dim = config.dim // config.num_heads
self.window_size = config.get("window_size", (-1, -1))
self.parallel_attention = None
if config.get("rotary_chunk", False):
@@ -28,7 +28,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.mask_map = None
-if self.config["cpu_offload"]:
+if self.config.get("cpu_offload", False):
if torch.cuda.get_device_capability(0) == (9, 0):
assert self.config["self_attn_1_type"] != "sage_attn2"
if "offload_ratio" in self.config:
...
@@ -17,6 +17,9 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import
WanTransformerInferTaylorCaching,
WanTransformerInferAdaCaching,
WanTransformerInferCustomCaching,
WanTransformerInferFirstBlock,
WanTransformerInferDualBlock,
WanTransformerInferDynamicBlock,
)
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
@@ -75,6 +78,12 @@ class WanModel:
self.transformer_infer_class = WanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = WanTransformerInferCustomCaching
elif self.config["feature_caching"] == "FirstBlock":
self.transformer_infer_class = WanTransformerInferFirstBlock
elif self.config["feature_caching"] == "DualBlock":
self.transformer_infer_class = WanTransformerInferDualBlock
elif self.config["feature_caching"] == "DynamicBlock":
self.transformer_infer_class = WanTransformerInferDynamicBlock
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -192,7 +201,7 @@ class WanModel:
@torch.no_grad()
def infer(self, inputs):
-if self.config["cpu_offload"]:
+if self.config.get("cpu_offload", False):
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
@@ -213,7 +222,7 @@ class WanModel:
self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
-if self.config["cpu_offload"]:
+if self.config.get("cpu_offload", False):
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
...
@@ -20,13 +20,13 @@ class DefaultRunner(BaseRunner):
super().__init__(config)
self.has_prompt_enhancer = False
self.progress_callback = None
-if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
+if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
self.has_prompt_enhancer = True
if not self.check_sub_servers("prompt_enhancer"):
self.has_prompt_enhancer = False
logger.warning("No prompt enhancer server available, disable prompt enhancer.")
if not self.has_prompt_enhancer:
-self.config["use_prompt_enhancer"] = False
+self.config.use_prompt_enhancer = False
self.set_init_device()
def init_modules(self):
@@ -43,7 +43,7 @@ class DefaultRunner(BaseRunner):
self.run_input_encoder = self._run_input_encoder_local_t2v
def set_init_device(self):
-if self.config["parallel_attn_type"]:
+if self.config.parallel_attn_type:
cur_rank = dist.get_rank()
torch.cuda.set_device(cur_rank)
if self.config.cpu_offload:
...
@@ -15,6 +15,9 @@ from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
WanSchedulerTaylorCaching,
WanSchedulerAdaCaching,
WanSchedulerCustomCaching,
WanSchedulerFirstBlock,
WanSchedulerDualBlock,
WanSchedulerDynamicBlock,
)
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.utils import *
@@ -169,6 +172,12 @@ class WanRunner(DefaultRunner):
scheduler = WanSchedulerAdaCaching(self.config)
elif self.config.feature_caching == "Custom":
scheduler = WanSchedulerCustomCaching(self.config)
elif self.config.feature_caching == "FirstBlock":
scheduler = WanSchedulerFirstBlock(self.config)
elif self.config.feature_caching == "DualBlock":
scheduler = WanSchedulerDualBlock(self.config)
elif self.config.feature_caching == "DynamicBlock":
scheduler = WanSchedulerDynamicBlock(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
self.model.set_scheduler(scheduler)
...
@@ -5,8 +5,8 @@ from lightx2v.utils.envs import *
class BaseScheduler:
def __init__(self, config):
self.config = config
-self.step_index = 0
self.latents = None
+self.step_index = 0
self.infer_steps = config.infer_steps
self.caching_records = [True] * config.infer_steps
self.flag_df = False
...
@@ -35,3 +35,27 @@ class WanSchedulerCustomCaching(WanScheduler):
def clear(self):
self.transformer_infer.clear()
class WanSchedulerFirstBlock(WanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
class WanSchedulerDualBlock(WanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
class WanSchedulerDynamicBlock(WanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
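The three new scheduler subclasses add no scheduling logic of their own; they exist so the runner can instantiate a caching-aware scheduler whose clear() drops the transformer-side residual caches (and empties the CUDA cache) between generations. A minimal sketch of that delegation, using hypothetical stand-in classes rather than the project's own:

import torch

class ToyCachingTransformerInfer:
    def __init__(self):
        self.residual_cache = {"even": None, "odd": None}

    def clear(self):
        # Drop cached residuals and hand freed blocks back to the CUDA allocator.
        self.residual_cache = {"even": None, "odd": None}
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

class ToyCachingScheduler:
    def __init__(self, transformer_infer):
        self.transformer_infer = transformer_infer

    def clear(self):
        # Mirrors WanSchedulerFirstBlock / DualBlock / DynamicBlock: pure delegation.
        self.transformer_infer.clear()

ToyCachingScheduler(ToyCachingTransformerInfer()).clear()  # safe to call between runs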