Commit fe9ccbda authored by gushiqiao, committed by GitHub

Merge pull request #64 from ModelTC/dev_cache

Update cache config and fix bugs.
parents a94695e5 7ec21ca6
File mode changed from 100644 to 100755 (4 files)
{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    },
    "feature_caching": "Tea",
    "coefficients": [
        [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01],
        [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
    ],
    "use_ret_steps": true,
    "teacache_thresh": 0.26
}
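The two rows in `coefficients` are 4th-degree polynomial coefficients (highest degree first) that TeaCache uses to rescale the relative L1 distance between successive modulated inputs; the refactored scheduler below selects row 0 when `use_ret_steps` is true and row 1 otherwise. A minimal sketch of that consumption, with illustrative variable names:

```python
# Minimal sketch of how a TeaCache config's "coefficients" rows are consumed.
# Variable names are illustrative; the real logic lives in
# WanSchedulerTeaCaching and WanTransformerInferTeaCaching.
import numpy as np

config = {
    "use_ret_steps": True,
    "teacache_thresh": 0.26,
    "coefficients": [
        [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01],
        [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01],
    ],
}

# Row 0 when use_ret_steps is enabled, row 1 otherwise.
row = 0 if config["use_ret_steps"] else 1
rescale = np.poly1d(config["coefficients"][row])  # highest-degree coefficient first

rel_l1 = 0.05  # hypothetical relative L1 distance between modulated inputs
accumulated = rescale(rel_l1)
can_skip = accumulated < config["teacache_thresh"]  # reuse cached residual if True
```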
@@ -13,5 +13,11 @@
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    },
-   "feature_caching": "Tea"
+   "feature_caching": "Tea",
+   "coefficients": [
+       [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02],
+       [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+   ],
+   "use_ret_steps": true,
+   "teacache_thresh": 0.26
}
@@ -14,5 +14,11 @@
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    },
-   "feature_caching": "Tea"
+   "feature_caching": "Tea",
+   "coefficients": [
+       [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01],
+       [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+   ],
+   "use_ret_steps": true,
+   "teacache_thresh": 0.26
}
{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    },
    "feature_caching": "Tea",
    "coefficients": [
        [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02],
        [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
    ],
    "use_ret_steps": true,
    "teacache_thresh": 0.26
}
{
-   "infer_steps": 40,
+   "infer_steps": 20,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
@@ -7,22 +7,29 @@
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
-   "enable_cfg": true,
+   "enable_cfg": false,
    "cpu_offload": true,
    "offload_granularity": "phase",
    "t5_offload_granularity": "block",
-   "dit_quantized_ckpt": "/path/to/dit_int8",
+   "dit_quantized_ckpt": "/wan_cfg_models/Wan2.1-I2V-480P-cfg-blocks-fp8/",
    "mm_config": {
-       "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
+       "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
    },
    "t5_quantized": true,
-   "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-int8.pth",
+   "t5_quantized_ckpt": "/wan_cfg_models/models_t5_umt5-xxl-enc-int8.pth",
    "t5_quant_scheme": "int8",
    "clip_quantized": true,
-   "clip_quantized_ckpt": "/path/to/clip_int8.pth",
+   "clip_quantized_ckpt": "/wan_cfg_models/clip-int8.pth",
    "clip_quant_scheme": "int8",
    "use_tiling_vae": true,
    "tiny_vae": true,
-   "tiny_vae_path": "/path/to/taew2_1.pth",
-   "lazy_load": true
+   "tiny_vae_path": "/mnt/afs_2/gushiqiao/x2v_models/taew2_1.pth",
+   "lazy_load": true,
+   "feature_caching": "Tea",
+   "coefficients": [
+       [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01],
+       [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
+   ],
+   "use_ret_steps": true,
+   "teacache_thresh": 0.26
}
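This config combines fp8/int8 quantization, phase-granular offload, a tiny VAE, lazy loading, and TeaCache in one file. A hedged sketch of how such a JSON might be loaded and sanity-checked (the helper is hypothetical, not the project's actual API):

```python
# Hypothetical config loader; key names come from the JSON above, the
# function itself is illustrative rather than the project's actual API.
import json

def load_cache_config(path: str) -> dict:
    with open(path) as f:
        cfg = json.load(f)
    if cfg.get("feature_caching") == "Tea":
        # TeaCache now expects both polynomial rows plus a skip threshold.
        assert len(cfg["coefficients"]) == 2, "expected [ret-steps row, non-ret row]"
        assert "teacache_thresh" in cfg and "use_ret_steps" in cfg
    return cfg
```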
@@ -244,4 +244,4 @@ class MemoryBuffer:
        with self.lock:
            if not self.cache:
                return -1
-           return max((key[0] + 1) % 40 for key in self.cache.keys())
+           return (list(self.cache.keys())[-1][0] + 1) % 40
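The old expression scans every cached key; the new one reads only the most recent insertion, which matters at the 40-step wrap-around. A toy comparison, assuming cache keys are tuples whose first element is a step index:

```python
# Toy comparison for the MemoryBuffer change. Assumed layout: cache keys are
# tuples whose first element is a step index over 40 inference steps.
cache = {(37, "cond"): 1, (38, "cond"): 2, (39, "cond"): 3}

old_next = max((key[0] + 1) % 40 for key in cache)   # max(38, 39, 0) == 39
new_next = (list(cache.keys())[-1][0] + 1) % 40      # (39 + 1) % 40 == 0

# The old form breaks at the wrap-around: after caching step 39 the next step
# should be 0, but max() over the wrapped values still reports 39. The new
# form trusts dict insertion order and follows only the latest entry.
assert (old_next, new_next) == (39, 0)
```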
@@ -120,12 +120,12 @@ class MMWeightQuantTemplate(MMWeightTemplate):
            self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
            if self.bias_name is not None:
-               self.bias = self.lazy_load_file.get_tensor(self.bias_name).pin_memory()
+               self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
        else:
            self.weight = self.lazy_load_file.get_tensor(self.weight_name)
            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
            if self.bias_name is not None:
-               self.bias = self.lazy_load_file.get_tensor(self.bias_name)
+               self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
        if self.weight_need_transpose:
            self.weight = self.weight.t()
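The fix casts the lazily loaded bias to bfloat16 before pinning, so the pinned host buffer already matches the activation dtype when it is copied to the GPU. A standalone sketch of that pattern, assuming a safetensors checkpoint with hypothetical tensor names:

```python
# Standalone sketch of the lazy-load path, assuming a safetensors checkpoint.
# Tensor names and the file path are hypothetical; the bfloat16 cast mirrors
# the diff above.
import torch
from safetensors import safe_open

with safe_open("dit_quantized.safetensors", framework="pt", device="cpu") as f:
    weight = f.get_tensor("blocks.0.self_attn.q.weight").pin_memory()
    weight_scale = f.get_tensor("blocks.0.self_attn.q.weight_scale").float().pin_memory()
    # Cast the bias to bfloat16 *before* pinning: the pinned buffer is copied
    # to the GPU asynchronously, and the matmul kernel expects a bf16 bias.
    bias = f.get_tensor("blocks.0.self_attn.q.bias").to(torch.bfloat16).pin_memory()
```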
@@ -19,7 +19,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
            else:
                rescale_func = np.poly1d(self.scheduler.coefficients)
                self.scheduler.accumulated_rel_l1_distance_even += rescale_func(
-                   ((modulated_inp - self.scheduler.previous_e0_even).abs().mean() / self.scheduler.previous_e0_even.abs().mean()).cpu().item()
+                   ((modulated_inp - self.scheduler.previous_e0_even.cuda()).abs().mean() / self.scheduler.previous_e0_even.cuda().abs().mean()).cpu().item()
                )
                if self.scheduler.accumulated_rel_l1_distance_even < self.scheduler.teacache_thresh:
                    should_calc_even = False
@@ -27,6 +27,11 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
                    should_calc_even = True
                    self.scheduler.accumulated_rel_l1_distance_even = 0
                self.scheduler.previous_e0_even = modulated_inp.clone()
+               if self.config["cpu_offload"]:
+                   self.scheduler.previous_e0_even = self.scheduler.previous_e0_even.cpu()
+                   modulated_inp = modulated_inp.cpu()
+                   del modulated_inp
+                   torch.cuda.empty_cache()
        else:  # odd -> unconditional pass
            self.scheduler.is_even = False
@@ -36,7 +41,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
            else:
                rescale_func = np.poly1d(self.scheduler.coefficients)
                self.scheduler.accumulated_rel_l1_distance_odd += rescale_func(
-                   ((modulated_inp - self.scheduler.previous_e0_odd).abs().mean() / self.scheduler.previous_e0_odd.abs().mean()).cpu().item()
+                   ((modulated_inp - self.scheduler.previous_e0_odd.cuda()).abs().mean() / self.scheduler.previous_e0_odd.cuda().abs().mean()).cpu().item()
                )
                if self.scheduler.accumulated_rel_l1_distance_odd < self.scheduler.teacache_thresh:
                    should_calc_odd = False
@@ -44,10 +49,15 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
                    should_calc_odd = True
                    self.scheduler.accumulated_rel_l1_distance_odd = 0
                self.scheduler.previous_e0_odd = modulated_inp.clone()
+               if self.config["cpu_offload"]:
+                   self.scheduler.previous_e0_odd = self.scheduler.previous_e0_odd.cpu()
+                   modulated_inp = modulated_inp.cpu()
+                   del modulated_inp
+                   torch.cuda.empty_cache()

        if self.scheduler.is_even:
            if not should_calc_even:
-               x += self.scheduler.previous_residual_even
+               x += self.scheduler.previous_residual_even.cuda()
            else:
                ori_x = x.clone()
                x = super().infer(
@@ -62,12 +72,13 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
                )
                self.scheduler.previous_residual_even = x - ori_x
                if self.config["cpu_offload"]:
+                   self.scheduler.previous_residual_even = self.scheduler.previous_residual_even.cpu()
                    ori_x = ori_x.to("cpu")
                    del ori_x
                    torch.cuda.empty_cache()
        else:
            if not should_calc_odd:
-               x += self.scheduler.previous_residual_odd
+               x += self.scheduler.previous_residual_odd.cuda()
            else:
                ori_x = x.clone()
                x = super().infer(
@@ -82,6 +93,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
                )
                self.scheduler.previous_residual_odd = x - ori_x
                if self.config["cpu_offload"]:
+                   self.scheduler.previous_residual_odd = self.scheduler.previous_residual_odd.cpu()
                    ori_x = ori_x.to("cpu")
                    del ori_x
                    torch.cuda.empty_cache()
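With `cpu_offload` enabled, every cached tensor (`previous_e0_*`, `previous_residual_*`) now lives on the CPU and is moved with `.cuda()` only at the point of use, which is why the comparisons above gained `.cuda()` calls. A condensed sketch of one branch (the even/conditional pass), with `run_blocks` standing in for `super().infer()`:

```python
# Condensed sketch of the even (conditional) TeaCache branch with CPU
# offload; `sched` mirrors the scheduler fields in the diff and `run_blocks`
# stands in for super().infer(). Illustrative, not the project's exact code.
import numpy as np

def teacache_even_step(x, modulated_inp, sched, cpu_offload, run_blocks):
    rescale = np.poly1d(sched.coefficients)
    prev = sched.previous_e0_even.cuda() if cpu_offload else sched.previous_e0_even
    sched.accumulated_rel_l1_distance_even += rescale(
        ((modulated_inp - prev).abs().mean() / prev.abs().mean()).cpu().item()
    )
    should_calc = sched.accumulated_rel_l1_distance_even >= sched.teacache_thresh
    if should_calc:
        sched.accumulated_rel_l1_distance_even = 0
    e0 = modulated_inp.clone()
    sched.previous_e0_even = e0.cpu() if cpu_offload else e0  # cache on host

    if not should_calc:
        # Cheap path: reuse the cached residual instead of running the blocks.
        x += sched.previous_residual_even.cuda() if cpu_offload else sched.previous_residual_even
    else:
        ori_x = x.clone()
        x = run_blocks(x)
        residual = x - ori_x
        sched.previous_residual_even = residual.cpu() if cpu_offload else residual
    return x
```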
@@ -59,7 +59,6 @@ class WanTransformerInfer:
    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)

-   # bug
    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        for block_idx in range(self.blocks_num):
            if block_idx == 0:
@@ -139,8 +138,6 @@
    def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-       self.weights_stream_mgr.prefetch_weights_from_disk(weights)
-       self.weights_stream_mgr._async_prefetch_block(weights)
        for block_idx in range(weights.blocks_num):
            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                weights.blocks[block_idx].modulation.to_cuda()
@@ -15,62 +15,14 @@ class WanSchedulerTeaCaching(WanScheduler):
        self.previous_residual_even = None
        self.previous_residual_odd = None
        self.use_ret_steps = self.config.use_ret_steps
-       if self.config.task == "i2v":
-           if self.use_ret_steps:
-               if self.config.target_width == 480 or self.config.target_height == 480:
-                   self.coefficients = [
-                       2.57151496e05,
-                       -3.54229917e04,
-                       1.40286849e03,
-                       -1.35890334e01,
-                       1.32517977e-01,
-                   ]
-               if self.config.target_width == 720 or self.config.target_height == 720:
-                   self.coefficients = [
-                       8.10705460e03,
-                       2.13393892e03,
-                       -3.72934672e02,
-                       1.66203073e01,
-                       -4.17769401e-02,
-                   ]
-               self.ret_steps = 5 * 2
-               self.cutoff_steps = self.config.infer_steps * 2
-           else:
-               if self.config.target_width == 480 or self.config.target_height == 480:
-                   self.coefficients = [
-                       -3.02331670e02,
-                       2.23948934e02,
-                       -5.25463970e01,
-                       5.87348440e00,
-                       -2.01973289e-01,
-                   ]
-               if self.config.target_width == 720 or self.config.target_height == 720:
-                   self.coefficients = [
-                       -114.36346466,
-                       65.26524496,
-                       -18.82220707,
-                       4.91518089,
-                       -0.23412683,
-                   ]
-               self.ret_steps = 1 * 2
-               self.cutoff_steps = self.config.infer_steps * 2 - 2
-       elif self.config.task == "t2v":
-           if self.use_ret_steps:
-               if "1.3B" in self.config.model_path:
-                   self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
-               if "14B" in self.config.model_path:
-                   self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
-               self.ret_steps = 5 * 2
-               self.cutoff_steps = self.config.infer_steps * 2
-           else:
-               if "1.3B" in self.config.model_path:
-                   self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
-               if "14B" in self.config.model_path:
-                   self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
-               self.ret_steps = 1 * 2
-               self.cutoff_steps = self.config.infer_steps * 2 - 2
+       if self.use_ret_steps:
+           self.coefficients = self.config.coefficients[0]
+           self.ret_steps = 5 * 2
+           self.cutoff_steps = self.config.infer_steps * 2
+       else:
+           self.coefficients = self.config.coefficients[1]
+           self.ret_steps = 1 * 2
+           self.cutoff_steps = self.config.infer_steps * 2 - 2

    def clear(self):
        if self.previous_e0_even is not None:
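The `* 2` factors reflect that with CFG enabled each denoising step makes two transformer calls (even = conditional, odd = unconditional, matching the branches above). Worked numbers for `infer_steps = 40`:

```python
# Worked example of the TeaCache step bookkeeping, assuming enable_cfg
# doubles the transformer calls (even = conditional, odd = unconditional).
infer_steps = 40

# use_ret_steps == True: always compute the first 5 steps (10 calls) and
# allow caching through the final call.
ret_steps_true, cutoff_true = 5 * 2, infer_steps * 2        # (10, 80)

# use_ret_steps == False: force only the first step and always compute the
# last one (hence the trailing - 2).
ret_steps_false, cutoff_false = 1 * 2, infer_steps * 2 - 2  # (2, 78)

assert (ret_steps_true, cutoff_true) == (10, 80)
assert (ret_steps_false, cutoff_false) == (2, 78)
```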
@@ -579,7 +579,7 @@ def main():
    model_type_keys_map = {
        "wan_dit": {
            "key_idx": 2,
-           "target_keys": ["self_attn", "cross_attnffn"],
+           "target_keys": ["self_attn", "cross_attn", "ffn"],
            "ignore_key": None,
        },
        "hunyuan_dit": {
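The bug here: the two target keys `cross_attn` and `ffn` had been fused into one string, so neither could ever match a tensor-name component and those layers were silently skipped during quantization. A sketch of the assumed routing (the real script may differ):

```python
# Assumed routing for model_type_keys_map (illustrative; the real script may
# differ). With key_idx == 2, matching looks at the third dot-separated
# component of a tensor name.
entry = {"key_idx": 2, "target_keys": ["self_attn", "cross_attn", "ffn"], "ignore_key": None}

def should_quantize(tensor_name: str) -> bool:
    parts = tensor_name.split(".")
    return len(parts) > entry["key_idx"] and parts[entry["key_idx"]] in entry["target_keys"]

assert should_quantize("blocks.0.cross_attn.k.weight")
assert should_quantize("blocks.0.ffn.0.weight")
# The fused "cross_attnffn" entry matched neither component, so cross_attn
# and ffn layers fell through.
assert not should_quantize("blocks.0.norm1.weight")
```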