Commit 20525490 authored by Rongjin Yang, committed by GitHub

04 cache dit (#165)

* Read through the BaseRunner -> DefaultRunner -> WanRunner init path

* FirstBlock

* DualBlock

* DynamicBlock

* DualBlock Downsample Factor

* Downsample
parent 38d11b82
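
This PR adds three feature-caching modes for the Wan DiT. FirstBlock always runs block 0 and, when its residual has barely changed since the previous step, skips the remaining blocks by reusing their cached residual. DualBlock always runs the first and last five blocks and caches only the middle. DynamicBlock keeps a per-block cache and decides block by block. The three new example configs below (DualBlock, DynamicBlock, FirstBlock, in that order) differ only in feature_caching and residual_diff_threshold: the per-block DynamicBlock mode uses a much tighter threshold (0.003) than the per-step modes (0.03 and 0.02).
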
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"feature_caching": "DualBlock",
"residual_diff_threshold": 0.03,
"downsample_factor": 2
}

{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"feature_caching": "DynamicBlock",
"residual_diff_threshold": 0.003,
"downsample_factor": 2
}

{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"feature_caching": "FirstBlock",
"residual_diff_threshold": 0.02,
"downsample_factor": 2
}
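
A minimal sketch of the decision rule these two knobs tune, distilled from the new inference classes below (the standalone function and its name are ours). The indicator tensor is the first-block residual (FirstBlock), the front-blocks residual (DualBlock), or each block's input (DynamicBlock):

import torch

def should_recompute(prev_indicator, cur_indicator, threshold, downsample_factor):
    # Compare on a strided slice of the last (hidden) dimension to keep the check cheap.
    t1 = prev_indicator[..., ::downsample_factor]
    t2 = cur_indicator[..., ::downsample_factor]
    # Relative mean absolute difference, normalized by the previous indicator's magnitude.
    diff = ((t1 - t2).abs().mean() / t1.abs().mean()).item()
    # True -> drift exceeded the threshold, run the blocks; False -> reuse the cached residual.
    return diff >= threshold
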
@@ -5,7 +5,6 @@ import numpy as np
import gc


# 1. TeaCaching
class WanTransformerInferTeaCaching(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
@@ -682,3 +681,306 @@ class WanTransformerInferCustomCaching(WanTransformerInfer, BaseTaylorCachingTransformerInfer):
        self.previous_e0_odd = None
        torch.cuda.empty_cache()

class WanTransformerInferFirstBlock(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.residual_diff_threshold = config.residual_diff_threshold
        # Separate caches for the conditional (even) and unconditional (odd) CFG passes.
        self.prev_first_block_residual_even = None
        self.prev_remaining_blocks_residual_even = None
        self.prev_first_block_residual_odd = None
        self.prev_remaining_blocks_residual_odd = None
        self.downsample_factor = self.config.downsample_factor

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        # Always run block 0; its residual is the cheap indicator for the whole step.
        ori_x = x.clone()
        x = super().infer_block(weights.blocks[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
        x_residual = x - ori_x
        del ori_x
        if self.infer_conditional:
            index = self.scheduler.step_index
            caching_records = self.scheduler.caching_records
            if index <= self.scheduler.infer_steps - 1:
                should_calc = self.calculate_should_calc(x_residual)
                self.scheduler.caching_records[index] = should_calc
            if caching_records[index]:
                x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
            else:
                x = self.infer_using_cache(x)
        else:
            index = self.scheduler.step_index
            caching_records_2 = self.scheduler.caching_records_2
            if index <= self.scheduler.infer_steps - 1:
                should_calc = self.calculate_should_calc(x_residual)
                self.scheduler.caching_records_2[index] = should_calc
            if caching_records_2[index]:
                x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
            else:
                x = self.infer_using_cache(x)
        if self.config.enable_cfg:
            self.switch_status()
        return x

    def calculate_should_calc(self, x_residual):
        # Relative mean-absolute difference between this step's first-block residual
        # and the previous step's, computed on a strided slice of the hidden dim.
        diff = 1.0
        x_residual_downsampled = x_residual[..., :: self.downsample_factor]
        if self.infer_conditional:
            if self.prev_first_block_residual_even is not None:
                t1 = self.prev_first_block_residual_even
                t2 = x_residual_downsampled
                mean_diff = (t1 - t2).abs().mean()
                mean_t1 = t1.abs().mean()
                diff = (mean_diff / mean_t1).item()
            self.prev_first_block_residual_even = x_residual_downsampled
        else:
            if self.prev_first_block_residual_odd is not None:
                t1 = self.prev_first_block_residual_odd
                t2 = x_residual_downsampled
                mean_diff = (t1 - t2).abs().mean()
                mean_t1 = t1.abs().mean()
                diff = (mean_diff / mean_t1).item()
            self.prev_first_block_residual_odd = x_residual_downsampled
        return diff >= self.residual_diff_threshold

    def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        # Full path: run the remaining blocks and cache their combined residual.
        ori_x = x.clone()
        for block_idx in range(1, self.blocks_num):
            x = super().infer_block(
                weights.blocks[block_idx],
                grid_sizes,
                embed,
                x,
                embed0,
                seq_lens,
                freqs,
                context,
            )
        if self.infer_conditional:
            self.prev_remaining_blocks_residual_even = x - ori_x
        else:
            self.prev_remaining_blocks_residual_odd = x - ori_x
        del ori_x
        return x

    def infer_using_cache(self, x):
        # Cached path: skip blocks 1..N-1 and reapply the stored residual in place.
        if self.infer_conditional:
            return x.add_(self.prev_remaining_blocks_residual_even)
        else:
            return x.add_(self.prev_remaining_blocks_residual_odd)

    def clear(self):
        self.prev_first_block_residual_even = None
        self.prev_remaining_blocks_residual_even = None
        self.prev_first_block_residual_odd = None
        self.prev_remaining_blocks_residual_odd = None
        torch.cuda.empty_cache()
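
Design note: with enable_cfg each denoising step makes a conditional and an unconditional forward pass, and switch_status() alternates between them, so the even/odd fields keep one indicator and one cached residual per branch. When the first block's residual has drifted less than residual_diff_threshold, blocks 1..N-1 are skipped outright and their cached combined residual is re-added in place.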

class WanTransformerInferDualBlock(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.residual_diff_threshold = config.residual_diff_threshold
        # Separate caches for the conditional (even) and unconditional (odd) CFG passes.
        self.prev_front_blocks_residual_even = None
        self.prev_middle_blocks_residual_even = None
        self.prev_front_blocks_residual_odd = None
        self.prev_middle_blocks_residual_odd = None
        self.downsample_factor = self.config.downsample_factor

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        # Always run the first five blocks; their combined residual is the indicator.
        ori_x = x.clone()
        for block_idx in range(0, 5):
            x = super().infer_block(
                weights.blocks[block_idx],
                grid_sizes,
                embed,
                x,
                embed0,
                seq_lens,
                freqs,
                context,
            )
        x_residual = x - ori_x
        del ori_x
        if self.infer_conditional:
            index = self.scheduler.step_index
            caching_records = self.scheduler.caching_records
            if index <= self.scheduler.infer_steps - 1:
                should_calc = self.calculate_should_calc(x_residual)
                self.scheduler.caching_records[index] = should_calc
            if caching_records[index]:
                x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
            else:
                x = self.infer_using_cache(x)
        else:
            index = self.scheduler.step_index
            caching_records_2 = self.scheduler.caching_records_2
            if index <= self.scheduler.infer_steps - 1:
                should_calc = self.calculate_should_calc(x_residual)
                self.scheduler.caching_records_2[index] = should_calc
            if caching_records_2[index]:
                x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
            else:
                x = self.infer_using_cache(x)
        # Always run the last five blocks, whether the middle was computed or cached.
        for block_idx in range(self.blocks_num - 5, self.blocks_num):
            x = super().infer_block(
                weights.blocks[block_idx],
                grid_sizes,
                embed,
                x,
                embed0,
                seq_lens,
                freqs,
                context,
            )
        if self.config.enable_cfg:
            self.switch_status()
        return x

    def calculate_should_calc(self, x_residual):
        # Relative mean-absolute difference between this step's front-blocks residual
        # and the previous step's, computed on a strided slice of the hidden dim.
        diff = 1.0
        x_residual_downsampled = x_residual[..., :: self.downsample_factor]
        if self.infer_conditional:
            if self.prev_front_blocks_residual_even is not None:
                t1 = self.prev_front_blocks_residual_even
                t2 = x_residual_downsampled
                mean_diff = (t1 - t2).abs().mean()
                mean_t1 = t1.abs().mean()
                diff = (mean_diff / mean_t1).item()
            self.prev_front_blocks_residual_even = x_residual_downsampled
        else:
            if self.prev_front_blocks_residual_odd is not None:
                t1 = self.prev_front_blocks_residual_odd
                t2 = x_residual_downsampled
                mean_diff = (t1 - t2).abs().mean()
                mean_t1 = t1.abs().mean()
                diff = (mean_diff / mean_t1).item()
            self.prev_front_blocks_residual_odd = x_residual_downsampled
        return diff >= self.residual_diff_threshold

    def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        # Full path: run the middle blocks and cache their combined residual.
        ori_x = x.clone()
        for block_idx in range(5, self.blocks_num - 5):
            x = super().infer_block(
                weights.blocks[block_idx],
                grid_sizes,
                embed,
                x,
                embed0,
                seq_lens,
                freqs,
                context,
            )
        if self.infer_conditional:
            self.prev_middle_blocks_residual_even = x - ori_x
        else:
            self.prev_middle_blocks_residual_odd = x - ori_x
        del ori_x
        return x

    def infer_using_cache(self, x):
        # Cached path: skip the middle blocks and reapply the stored residual in place.
        if self.infer_conditional:
            return x.add_(self.prev_middle_blocks_residual_even)
        else:
            return x.add_(self.prev_middle_blocks_residual_odd)

    def clear(self):
        self.prev_front_blocks_residual_even = None
        self.prev_middle_blocks_residual_even = None
        self.prev_front_blocks_residual_odd = None
        self.prev_middle_blocks_residual_odd = None
        torch.cuda.empty_cache()
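
Compared with FirstBlock, DualBlock buys a sturdier signal and a safer fallback: the residual of blocks 0-4 drives the decision, only the middle blocks (5..N-6) are ever skipped, and the final five blocks always run, re-conditioning the output even on cached steps.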

class WanTransformerInferDynamicBlock(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.residual_diff_threshold = config.residual_diff_threshold
        self.downsample_factor = self.config.downsample_factor
        # Per-block caches of the last computed input and residual, kept separately
        # for the conditional (even) and unconditional (odd) CFG passes.
        self.block_in_cache_even = {i: None for i in range(self.blocks_num)}
        self.block_residual_cache_even = {i: None for i in range(self.blocks_num)}
        self.block_in_cache_odd = {i: None for i in range(self.blocks_num)}
        self.block_residual_cache_odd = {i: None for i in range(self.blocks_num)}

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        for block_idx in range(self.blocks_num):
            x = self.infer_block(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx)
        return x

    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx):
        ori_x = x.clone()
        if self.infer_conditional:
            if self.block_in_cache_even[block_idx] is not None:
                # True means this block's input drifted past the threshold: recompute it.
                should_calc = self.are_two_tensor_similar(self.block_in_cache_even[block_idx], x)
                if should_calc:
                    x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
                else:
                    x += self.block_residual_cache_even[block_idx]
            else:
                x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
            self.block_in_cache_even[block_idx] = ori_x
            self.block_residual_cache_even[block_idx] = x - ori_x
            del ori_x
        else:
            if self.block_in_cache_odd[block_idx] is not None:
                should_calc = self.are_two_tensor_similar(self.block_in_cache_odd[block_idx], x)
                if should_calc:
                    x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
                else:
                    x += self.block_residual_cache_odd[block_idx]
            else:
                x = super().infer_block(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
            self.block_in_cache_odd[block_idx] = ori_x
            self.block_residual_cache_odd[block_idx] = x - ori_x
            del ori_x
        return x

    def are_two_tensor_similar(self, t1, t2):
        # Despite the name, this returns True when the tensors are NOT similar,
        # i.e. when the block should be recomputed.
        t1_downsampled = t1[..., :: self.downsample_factor]
        t2_downsampled = t2[..., :: self.downsample_factor]
        mean_diff = (t1_downsampled - t2_downsampled).abs().mean()
        mean_t1 = t1_downsampled.abs().mean()
        diff = (mean_diff / mean_t1).item()
        return diff >= self.residual_diff_threshold

    def clear(self):
        for i in range(self.blocks_num):
            self.block_in_cache_even[i] = None
            self.block_residual_cache_even[i] = None
            self.block_in_cache_odd[i] = None
            self.block_residual_cache_odd[i] = None
        torch.cuda.empty_cache()
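
DynamicBlock is the finest-grained of the three: each block is skipped independently whenever its current input is still within residual_diff_threshold of the input it last computed on, at the cost of caching two tensors per block per CFG branch (hence the much lower 0.003 threshold in its config).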
@@ -12,12 +12,12 @@ from functools import partial
class WanTransformerInfer(BaseTransformerInfer):
    def __init__(self, config):
        self.config = config
-        self.task = config["task"]
+        self.task = config.task
        self.attention_type = config.get("attention_type", "flash_attn2")
-        self.blocks_num = config["num_layers"]
+        self.blocks_num = config.num_layers
        self.phases_num = 4
-        self.num_heads = config["num_heads"]
-        self.head_dim = config["dim"] // config["num_heads"]
+        self.num_heads = config.num_heads
+        self.head_dim = config.dim // config.num_heads
        self.window_size = config.get("window_size", (-1, -1))
        self.parallel_attention = None
        if config.get("rotary_chunk", False):
@@ -28,7 +28,7 @@ class WanTransformerInfer(BaseTransformerInfer):
        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
        self.mask_map = None
-        if self.config["cpu_offload"]:
+        if self.config.get("cpu_offload", False):
            if torch.cuda.get_device_capability(0) == (9, 0):
                assert self.config["self_attn_1_type"] != "sage_attn2"
            if "offload_ratio" in self.config:
@@ -17,6 +17,9 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
    WanTransformerInferTaylorCaching,
    WanTransformerInferAdaCaching,
    WanTransformerInferCustomCaching,
+   WanTransformerInferFirstBlock,
+   WanTransformerInferDualBlock,
+   WanTransformerInferDynamicBlock,
)
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
@@ -75,6 +78,12 @@ class WanModel:
            self.transformer_infer_class = WanTransformerInferAdaCaching
        elif self.config["feature_caching"] == "Custom":
            self.transformer_infer_class = WanTransformerInferCustomCaching
+        elif self.config["feature_caching"] == "FirstBlock":
+            self.transformer_infer_class = WanTransformerInferFirstBlock
+        elif self.config["feature_caching"] == "DualBlock":
+            self.transformer_infer_class = WanTransformerInferDualBlock
+        elif self.config["feature_caching"] == "DynamicBlock":
+            self.transformer_infer_class = WanTransformerInferDynamicBlock
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -192,7 +201,7 @@ class WanModel:
    @torch.no_grad()
    def infer(self, inputs):
-        if self.config["cpu_offload"]:
+        if self.config.get("cpu_offload", False):
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()
@@ -213,7 +222,7 @@ class WanModel:
            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

-        if self.config["cpu_offload"]:
+        if self.config.get("cpu_offload", False):
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()
@@ -20,13 +20,13 @@ class DefaultRunner(BaseRunner):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
-        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
+        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
-            self.config["use_prompt_enhancer"] = False
+            self.config.use_prompt_enhancer = False
        self.set_init_device()

    def init_modules(self):
@@ -43,7 +43,7 @@ class DefaultRunner(BaseRunner):
            self.run_input_encoder = self._run_input_encoder_local_t2v

    def set_init_device(self):
-        if self.config["parallel_attn_type"]:
+        if self.config.parallel_attn_type:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        if self.config.cpu_offload:
@@ -15,6 +15,9 @@ from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
    WanSchedulerTaylorCaching,
    WanSchedulerAdaCaching,
    WanSchedulerCustomCaching,
+   WanSchedulerFirstBlock,
+   WanSchedulerDualBlock,
+   WanSchedulerDynamicBlock,
)
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.utils import *
@@ -169,6 +172,12 @@ class WanRunner(DefaultRunner):
            scheduler = WanSchedulerAdaCaching(self.config)
        elif self.config.feature_caching == "Custom":
            scheduler = WanSchedulerCustomCaching(self.config)
+        elif self.config.feature_caching == "FirstBlock":
+            scheduler = WanSchedulerFirstBlock(self.config)
+        elif self.config.feature_caching == "DualBlock":
+            scheduler = WanSchedulerDualBlock(self.config)
+        elif self.config.feature_caching == "DynamicBlock":
+            scheduler = WanSchedulerDynamicBlock(self.config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler)
@@ -5,8 +5,8 @@ from lightx2v.utils.envs import *
class BaseScheduler:
    def __init__(self, config):
        self.config = config
-        self.step_index = 0
        self.latents = None
+        self.step_index = 0
        self.infer_steps = config.infer_steps
        self.caching_records = [True] * config.infer_steps
        self.flag_df = False
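
caching_records starts all-True, so step 0 (which has no previous indicator to compare against) always computes; the new caching transformers then overwrite each entry with their per-step decision. FirstBlock and DualBlock also read a parallel caching_records_2 list for the unconditional CFG pass, presumably initialized the same way on the Wan scheduler side (not shown in this hunk).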
@@ -35,3 +35,27 @@ class WanSchedulerCustomCaching(WanScheduler):
    def clear(self):
        self.transformer_infer.clear()


class WanSchedulerFirstBlock(WanScheduler):
    def __init__(self, config):
        super().__init__(config)

    def clear(self):
        self.transformer_infer.clear()


class WanSchedulerDualBlock(WanScheduler):
    def __init__(self, config):
        super().__init__(config)

    def clear(self):
        self.transformer_infer.clear()


class WanSchedulerDynamicBlock(WanScheduler):
    def __init__(self, config):
        super().__init__(config)

    def clear(self):
        self.transformer_infer.clear()
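
The three new scheduler classes add no scheduling logic of their own: they give the runner dispatch a concrete type per mode, and their clear() delegates to the transformer-infer object so the residual caches (and the CUDA allocator cache) are released between runs.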