Commit fe9ccbda authored by gushiqiao, committed by GitHub

Merge pull request #64 from ModelTC/dev_cache

Update cache config and fix bugs.
parents a94695e5 7ec21ca6
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    },
    "feature_caching": "Tea",
    "coefficients": [
        [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01],
        [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
    ],
    "use_ret_steps": true,
    "teacache_thresh": 0.26
}
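
The two rows in "coefficients" are the TeaCache polynomial fits that the scheduler refactor later in this diff selects between (row 0 when "use_ret_steps" is true, row 1 otherwise). A minimal sketch of how they are consumed, distilled from the WanTransformerInferTeaCaching hunk below; the names prev_inp, cur_inp, and accumulated are illustrative, not the repo's API:

import numpy as np
import torch

coefficients = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01]
rescale_func = np.poly1d(coefficients)  # degree-4 polynomial in the relative L1 distance

def should_reuse_cache(prev_inp, cur_inp, accumulated, thresh=0.26):
    # Relative L1 distance between the current and previous modulated inputs.
    rel_l1 = ((cur_inp - prev_inp).abs().mean() / prev_inp.abs().mean()).item()
    accumulated += rescale_func(rel_l1)
    if accumulated < thresh:
        return True, accumulated  # below teacache_thresh: reuse the cached residual
    return False, 0.0             # recompute the block stack and reset the accumulator
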
@@ -13,5 +13,11 @@
         "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
         "weight_auto_quant": true
     },
-    "feature_caching": "Tea"
+    "feature_caching": "Tea",
+    "coefficients": [
+        [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02],
+        [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+    ],
+    "use_ret_steps": true,
+    "teacache_thresh": 0.26
 }
@@ -14,5 +14,11 @@
         "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
         "weight_auto_quant": true
     },
-    "feature_caching": "Tea"
+    "feature_caching": "Tea",
+    "coefficients": [
+        [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01],
+        [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+    ],
+    "use_ret_steps": true,
+    "teacache_thresh": 0.26
 }
{
    "infer_steps": 50,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    },
    "feature_caching": "Tea",
    "coefficients": [
        [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02],
        [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
    ],
    "use_ret_steps": true,
    "teacache_thresh": 0.26
}
@@ -1,5 +1,5 @@
 {
-    "infer_steps": 40,
+    "infer_steps": 20,
     "target_video_length": 81,
     "target_height": 480,
     "target_width": 832,
@@ -7,22 +7,29 @@
     "seed": 42,
     "sample_guide_scale": 5,
     "sample_shift": 5,
-    "enable_cfg": true,
+    "enable_cfg": false,
     "cpu_offload": true,
     "offload_granularity": "phase",
     "t5_offload_granularity": "block",
-    "dit_quantized_ckpt": "/path/to/dit_int8",
+    "dit_quantized_ckpt": "/wan_cfg_models/Wan2.1-I2V-480P-cfg-blocks-fp8/",
     "mm_config": {
-        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
     },
     "t5_quantized": true,
-    "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-int8.pth",
+    "t5_quantized_ckpt": "/wan_cfg_models/models_t5_umt5-xxl-enc-int8.pth",
     "t5_quant_scheme": "int8",
     "clip_quantized": true,
-    "clip_quantized_ckpt": "/path/to/clip_int8.pth",
+    "clip_quantized_ckpt": "/wan_cfg_models/clip-int8.pth",
     "clip_quant_scheme": "int8",
     "use_tiling_vae": true,
     "tiny_vae": true,
-    "tiny_vae_path": "/path/to/taew2_1.pth",
-    "lazy_load": true
+    "tiny_vae_path": "/mnt/afs_2/gushiqiao/x2v_models/taew2_1.pth",
+    "lazy_load": true,
+    "feature_caching": "Tea",
+    "coefficients": [
+        [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01],
+        [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01]
+    ],
+    "use_ret_steps": true,
+    "teacache_thresh": 0.26
 }
@@ -244,4 +244,4 @@ class MemoryBuffer:
         with self.lock:
             if not self.cache:
                 return -1
-            return max((key[0] + 1) % 40 for key in self.cache.keys())
+            return (list(self.cache.keys())[-1][0] + 1) % 40
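
This swaps a max over all cached keys for the most recently inserted key, and the old expression breaks at the wrap-around: once block 39 is cached, max((key[0] + 1) % 40 ...) can never yield 0. A small illustration, assuming self.cache is a dict keyed by (block_idx, phase) tuples whose insertion order follows block order (guaranteed for dicts since Python 3.7); the entries are hypothetical:

cache = {(38, 0): "w38", (39, 0): "w39"}

old_next = max((key[0] + 1) % 40 for key in cache.keys())  # max(39, 0) == 39 (stale at the wrap)
new_next = (list(cache.keys())[-1][0] + 1) % 40            # (39 + 1) % 40 == 0 (wraps correctly)
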
@@ -120,12 +120,12 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
             self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
             if self.bias_name is not None:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).pin_memory()
+                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
         else:
             self.weight = self.lazy_load_file.get_tensor(self.weight_name)
             self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
             if self.bias_name is not None:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name)
+                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
         if self.weight_need_transpose:
             self.weight = self.weight.t()
...
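
Casting the lazily loaded bias to torch.bfloat16 once at load time matches it to the activation dtype downstream; otherwise adding a checkpoint-dtype bias silently upcasts the result through PyTorch type promotion. A hedged illustration of that promotion issue, with hypothetical shapes and dtypes:

import torch

out = torch.zeros(4, 8, dtype=torch.bfloat16)    # e.g. the result of a quantized matmul
bias_fp32 = torch.zeros(8, dtype=torch.float32)  # bias as stored in the checkpoint

assert (out + bias_fp32).dtype == torch.float32                       # silent upcast
assert (out + bias_fp32.to(torch.bfloat16)).dtype == torch.bfloat16   # stays bf16, as intended
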
@@ -19,7 +19,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
             else:
                 rescale_func = np.poly1d(self.scheduler.coefficients)
                 self.scheduler.accumulated_rel_l1_distance_even += rescale_func(
-                    ((modulated_inp - self.scheduler.previous_e0_even).abs().mean() / self.scheduler.previous_e0_even.abs().mean()).cpu().item()
+                    ((modulated_inp - self.scheduler.previous_e0_even.cuda()).abs().mean() / self.scheduler.previous_e0_even.cuda().abs().mean()).cpu().item()
                 )
                 if self.scheduler.accumulated_rel_l1_distance_even < self.scheduler.teacache_thresh:
                     should_calc_even = False
@@ -27,6 +27,11 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
                     should_calc_even = True
                     self.scheduler.accumulated_rel_l1_distance_even = 0
             self.scheduler.previous_e0_even = modulated_inp.clone()
+            if self.config["cpu_offload"]:
+                self.scheduler.previous_e0_even = self.scheduler.previous_e0_even.cpu()
+                modulated_inp = modulated_inp.cpu()
+                del modulated_inp
+                torch.cuda.empty_cache()
         else:  # odd -> uncondition
             self.scheduler.is_even = False
@@ -36,7 +41,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
             else:
                 rescale_func = np.poly1d(self.scheduler.coefficients)
                 self.scheduler.accumulated_rel_l1_distance_odd += rescale_func(
-                    ((modulated_inp - self.scheduler.previous_e0_odd).abs().mean() / self.scheduler.previous_e0_odd.abs().mean()).cpu().item()
+                    ((modulated_inp - self.scheduler.previous_e0_odd.cuda()).abs().mean() / self.scheduler.previous_e0_odd.cuda().abs().mean()).cpu().item()
                 )
                 if self.scheduler.accumulated_rel_l1_distance_odd < self.scheduler.teacache_thresh:
                     should_calc_odd = False
@@ -44,10 +49,15 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
                     should_calc_odd = True
                     self.scheduler.accumulated_rel_l1_distance_odd = 0
             self.scheduler.previous_e0_odd = modulated_inp.clone()
+            if self.config["cpu_offload"]:
+                self.scheduler.previous_e0_odd = self.scheduler.previous_e0_odd.cpu()
+                modulated_inp = modulated_inp.cpu()
+                del modulated_inp
+                torch.cuda.empty_cache()

         if self.scheduler.is_even:
             if not should_calc_even:
-                x += self.scheduler.previous_residual_even
+                x += self.scheduler.previous_residual_even.cuda()
             else:
                 ori_x = x.clone()
                 x = super().infer(
@@ -62,12 +72,13 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
                 )
                 self.scheduler.previous_residual_even = x - ori_x
                 if self.config["cpu_offload"]:
+                    self.scheduler.previous_residual_even = self.scheduler.previous_residual_even.cpu()
                     ori_x = ori_x.to("cpu")
                     del ori_x
                     torch.cuda.empty_cache()
         else:
             if not should_calc_odd:
-                x += self.scheduler.previous_residual_odd
+                x += self.scheduler.previous_residual_odd.cuda()
             else:
                 ori_x = x.clone()
                 x = super().infer(
@@ -82,6 +93,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
                 )
                 self.scheduler.previous_residual_odd = x - ori_x
                 if self.config["cpu_offload"]:
+                    self.scheduler.previous_residual_odd = self.scheduler.previous_residual_odd.cpu()
                     ori_x = ori_x.to("cpu")
                     del ori_x
                     torch.cuda.empty_cache()
...
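
The pattern added throughout this hunk is symmetric: every cached tensor (previous_e0_*, previous_residual_*) is parked on the CPU when cpu_offload is set, and pulled back with .cuda() right before it is read. A minimal sketch of the idea; the function names are illustrative, not the repo's API, and the round trip assumes a CUDA device:

import torch

def stash(t, cpu_offload):
    # Park a cached tensor on the host so it does not hold VRAM between steps.
    return t.cpu() if cpu_offload else t

def fetch(t):
    # Bring a stashed tensor back to the GPU right before it is consumed.
    return t.cuda() if not t.is_cuda else t

# Mirroring the diff: previous_residual_even = stash(x - ori_x, config["cpu_offload"])
# and on a cache hit: x += fetch(previous_residual_even)
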
@@ -59,7 +59,6 @@ class WanTransformerInfer:
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)

-    # bug
     def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
@@ -139,8 +138,6 @@ class WanTransformerInfer:
     def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights)
-        self.weights_stream_mgr._async_prefetch_block(weights)
         for block_idx in range(weights.blocks_num):
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                 weights.blocks[block_idx].modulation.to_cuda()
...
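
The deleted _async_prefetch_block call was issued right before a loop that already drives its own prefetching, so the first block's weights were uploaded twice. A hedged sketch of the double-buffering pattern such a loop relies on; the to_cuda method and block structure are hypothetical stand-ins rather than the repo's weights_stream_mgr API, and a CUDA device is assumed:

import torch

def stream_blocks(blocks, run_block):
    # Assumes blocks[0] is already resident (the real loop special-cases block_idx == 0).
    copy_stream = torch.cuda.Stream()
    for k, block in enumerate(blocks):
        torch.cuda.current_stream().wait_stream(copy_stream)  # prior upload finished
        if k + 1 < len(blocks):
            with torch.cuda.stream(copy_stream):
                blocks[k + 1].to_cuda(non_blocking=True)      # overlap: prefetch block k+1
        run_block(block)                                      # compute on block k
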
@@ -15,62 +15,14 @@ class WanSchedulerTeaCaching(WanScheduler):
         self.previous_residual_even = None
         self.previous_residual_odd = None
         self.use_ret_steps = self.config.use_ret_steps
-        if self.config.task == "i2v":
-            if self.use_ret_steps:
-                if self.config.target_width == 480 or self.config.target_height == 480:
-                    self.coefficients = [
-                        2.57151496e05,
-                        -3.54229917e04,
-                        1.40286849e03,
-                        -1.35890334e01,
-                        1.32517977e-01,
-                    ]
-                if self.config.target_width == 720 or self.config.target_height == 720:
-                    self.coefficients = [
-                        8.10705460e03,
-                        2.13393892e03,
-                        -3.72934672e02,
-                        1.66203073e01,
-                        -4.17769401e-02,
-                    ]
-                self.ret_steps = 5 * 2
-                self.cutoff_steps = self.config.infer_steps * 2
-            else:
-                if self.config.target_width == 480 or self.config.target_height == 480:
-                    self.coefficients = [
-                        -3.02331670e02,
-                        2.23948934e02,
-                        -5.25463970e01,
-                        5.87348440e00,
-                        -2.01973289e-01,
-                    ]
-                if self.config.target_width == 720 or self.config.target_height == 720:
-                    self.coefficients = [
-                        -114.36346466,
-                        65.26524496,
-                        -18.82220707,
-                        4.91518089,
-                        -0.23412683,
-                    ]
-                self.ret_steps = 1 * 2
-                self.cutoff_steps = self.config.infer_steps * 2 - 2
-        elif self.config.task == "t2v":
-            if self.use_ret_steps:
-                if "1.3B" in self.config.model_path:
-                    self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
-                if "14B" in self.config.model_path:
-                    self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
-                self.ret_steps = 5 * 2
-                self.cutoff_steps = self.config.infer_steps * 2
-            else:
-                if "1.3B" in self.config.model_path:
-                    self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
-                if "14B" in self.config.model_path:
-                    self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
-                self.ret_steps = 1 * 2
-                self.cutoff_steps = self.config.infer_steps * 2 - 2
+        if self.use_ret_steps:
+            self.coefficients = self.config.coefficients[0]
+            self.ret_steps = 5 * 2
+            self.cutoff_steps = self.config.infer_steps * 2
+        else:
+            self.coefficients = self.config.coefficients[1]
+            self.ret_steps = 1 * 2
+            self.cutoff_steps = self.config.infer_steps * 2 - 2

     def clear(self):
         if self.previous_e0_even is not None:
...
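
The refactor replaces the task/resolution/model-size ladder with two rows read from each JSON config, which is why every config touched above gains a "coefficients" array: row 0 pairs with "use_ret_steps": true, row 1 with false. A sketch of the new selection against one of the configs from this diff:

config = {
    "infer_steps": 40,
    "use_ret_steps": True,
    "coefficients": [
        [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01],  # ret-steps fit
        [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01],  # default fit
    ],
}

if config["use_ret_steps"]:
    coefficients = config["coefficients"][0]
    ret_steps, cutoff_steps = 5 * 2, config["infer_steps"] * 2
else:
    coefficients = config["coefficients"][1]
    ret_steps, cutoff_steps = 1 * 2, config["infer_steps"] * 2 - 2
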
@@ -579,7 +579,7 @@ def main():
     model_type_keys_map = {
         "wan_dit": {
             "key_idx": 2,
-            "target_keys": ["self_attn", "cross_attnffn"],
+            "target_keys": ["self_attn", "cross_attn", "ffn"],
             "ignore_key": None,
         },
         "hunyuan_dit": {
...
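
The old list fused two keys into the single token "cross_attnffn", which matches no module name, so cross_attn and ffn layers were skipped during conversion. A hypothetical reconstruction of how key_idx and target_keys drive the match; the parameter-name layout shown is an assumption:

name = "blocks.0.ffn.0.weight"  # hypothetical wan_dit parameter name
key_idx = 2
target_keys = ["self_attn", "cross_attn", "ffn"]

should_convert = name.split(".")[key_idx] in target_keys  # True; "cross_attnffn" never matched
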