Commit fba9754a authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Reconstruct] recon infer class (#228)

parent d8a2731b
import glob import glob
import os import os
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
from lightx2v.models.networks.wan.model import WanModel from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
...@@ -23,6 +24,7 @@ class WanAudioModel(WanModel): ...@@ -23,6 +24,7 @@ class WanAudioModel(WanModel):
super()._init_infer_class() super()._init_infer_class()
self.pre_infer_class = WanAudioPreInfer self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer self.post_infer_class = WanAudioPostInfer
self.transformer_infer_class = WanAudioTransformerInfer
class Wan22MoeAudioModel(WanAudioModel): class Wan22MoeAudioModel(WanAudioModel):
......
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, compute_freqs_audio_dist
class WanAudioTransformerInfer(WanOffloadTransformerInfer):
    """Transformer inference for the audio-conditioned Wan model.

    Differs from the plain offload infer in two ways: rotary frequencies are
    computed with the audio-specific helpers, and after the standard residual
    post-processing each registered audio-dit block may further modify `x`.
    """

    def __init__(self, config):
        super().__init__(config)

    def compute_freqs(self, q, grid_sizes, freqs):
        """Return rotary frequencies for `q` using the audio helpers.

        Uses the distributed variant when sequence parallelism is enabled
        (sharding over `self.seq_p_group`), otherwise the single-device one.
        """
        half_head_dim = q.size(2) // 2
        if self.config["seq_parallel"]:
            return compute_freqs_audio_dist(q.size(0), half_head_dim, grid_sizes, freqs, self.seq_p_group)
        return compute_freqs_audio(half_head_dim, grid_sizes, freqs)

    def post_process(self, x, y, c_gate_msa, pre_infer_out):
        """Run the base residual update, then apply any audio-dit modifiers.

        Each entry of `pre_infer_out.audio_dit_blocks` maps a block index to a
        dict with a `modify_func` and its `kwargs`; only entries matching the
        current `self.block_idx` are applied. The `hasattr` guard skips the
        step when no block loop has set `block_idx` yet.
        """
        x = super().post_process(x, y, c_gate_msa, pre_infer_out)
        audio_blocks = pre_infer_out.audio_dit_blocks
        if audio_blocks is not None and hasattr(self, "block_idx"):
            for ipa_out in audio_blocks:
                if self.block_idx not in ipa_out:
                    continue
                modifier = ipa_out[self.block_idx]
                x = modifier["modify_func"](x, pre_infer_out.grid_sizes, **modifier["kwargs"])
        return x
...@@ -2,13 +2,13 @@ import math ...@@ -2,13 +2,13 @@ import math
import torch import torch
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from ..transformer_infer import WanTransformerInfer
from ..utils import apply_rotary_emb, compute_freqs_causvid from ..utils import apply_rotary_emb, compute_freqs_causvid
class WanTransformerInferCausVid(WanTransformerInfer): class WanTransformerInferCausVid(WanOffloadTransformerInfer):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.num_frames = config["num_frames"] self.num_frames = config["num_frames"]
......
...@@ -6,11 +6,10 @@ import torch ...@@ -6,11 +6,10 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from ..transformer_infer import WanTransformerInfer
class WanTransformerInferCaching(WanOffloadTransformerInfer):
class WanTransformerInferCaching(WanTransformerInfer):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.must_calc_steps = [] self.must_calc_steps = []
......
import torch
from lightx2v.common.offload.manager import (
LazyWeightAsyncStreamManager,
WeightAsyncStreamManager,
)
from ..transformer_infer import WanTransformerInfer
class WanOffloadTransformerInfer(WanTransformerInfer):
    """Transformer inference that streams weights between CPU and GPU.

    The streaming strategy is selected from `config["offload_granularity"]`:

    - ``"block"``: one transformer block resides on the GPU at a time; the next
      block is prefetched on a side stream while the current one computes.
    - ``"phase"``: the same pipeline at the finer granularity of a block's
      compute phases (modulation, self-attn, cross-attn, ffn).
    - ``"model"``: no streaming; falls back to the base-class
      ``infer_without_offload``.

    With ``config["lazy_load"]`` enabled, weights are additionally staged from
    disk into a pinned-memory buffer by `LazyWeightAsyncStreamManager`.
    """

    def __init__(self, config):
        super().__init__(config)
        if self.config.get("cpu_offload", False):
            # Fraction of blocks managed by the offload pipeline; 1 == all.
            offload_ratio = self.config.get("offload_ratio", 1)
            offload_granularity = self.config.get("offload_granularity", "block")
            if offload_granularity == "block":
                if not self.config.get("lazy_load", False):
                    self.infer_func = self.infer_with_offload
                else:
                    self.infer_func = self.infer_with_lazy_offload
            elif offload_granularity == "phase":
                if not self.config.get("lazy_load", False):
                    self.infer_func = self.infer_with_phases_offload
                else:
                    self.infer_func = self.infer_with_phases_lazy_offload
            elif offload_granularity == "model":
                # Fix: the base-class method is `infer_without_offload` (it was
                # renamed from `_infer_without_offload`); referencing the old
                # name raised AttributeError at inference time.
                self.infer_func = self.infer_without_offload
            if offload_granularity != "model":
                if not self.config.get("lazy_load", False):
                    self.weights_stream_mgr = WeightAsyncStreamManager(
                        blocks_num=self.blocks_num,
                        offload_ratio=offload_ratio,
                        phases_num=self.phases_num,
                    )
                else:
                    self.weights_stream_mgr = LazyWeightAsyncStreamManager(
                        blocks_num=self.blocks_num,
                        offload_ratio=offload_ratio,
                        phases_num=self.phases_num,
                        num_disk_workers=self.config.get("num_disk_workers", 2),
                        max_memory=self.config.get("max_memory", 2),
                        offload_gra=offload_granularity,
                    )

    def infer_with_offload(self, weights, x, pre_infer_out):
        """Block-granularity offload: compute block k while prefetching k+1."""
        for block_idx in range(self.blocks_num):
            self.block_idx = block_idx
            if block_idx == 0:
                self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
                self.weights_stream_mgr.active_weights[0].to_cuda()
            if block_idx < self.blocks_num - 1:
                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                # Fix: compute with the CUDA-resident active weights instead of
                # indexing `weights.blocks` directly, so the prefetch/swap
                # pipeline actually feeds the block being computed.
                x = self.infer_block(self.weights_stream_mgr.active_weights[0], x, pre_infer_out)
            self.weights_stream_mgr.swap_weights()
        return x

    def infer_with_lazy_offload(self, weights, x, pre_infer_out):
        """Block-granularity offload with lazy disk staging via a pinned buffer."""
        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
        for block_idx in range(self.blocks_num):
            # Fix: record the block index so post_process hooks (e.g. audio-dit)
            # that key on `self.block_idx` see the current block, matching the
            # other offload loops.
            self.block_idx = block_idx
            if block_idx == 0:
                block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
                block.to_cuda()
                self.weights_stream_mgr.active_weights[0] = (block_idx, block)
            if block_idx < self.blocks_num - 1:
                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                # Fix: under lazy load the loaded weights live in the pinned
                # buffer and `active_weights[0]` is a (block_idx, block) pair;
                # `weights.blocks[block_idx]` is the not-yet-loaded object.
                x = self.infer_block(self.weights_stream_mgr.active_weights[0][1], x, pre_infer_out)
            self.weights_stream_mgr.swap_weights()
            if block_idx == self.blocks_num - 1:
                self.weights_stream_mgr.pin_memory_buffer.pop_front()
            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
        if self.clean_cuda_cache:
            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
            torch.cuda.empty_cache()
        return x

    def infer_with_phases_offload(self, weights, x, pre_infer_out):
        """Phase-granularity offload: stream one compute phase at a time.

        Phase indices map to: 0 modulation, 1 self-attention, 2 cross-attention,
        3 feed-forward + residual post-process. The next phase is prefetched
        while the active one computes on the dedicated compute stream.
        """
        for block_idx in range(weights.blocks_num):
            self.block_idx = block_idx
            for phase_idx in range(self.phases_num):
                if block_idx == 0 and phase_idx == 0:
                    phase = weights.blocks[block_idx].compute_phases[phase_idx]
                    phase.to_cuda()
                    self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
                with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                    cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
                    if cur_phase_idx == 0:
                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, pre_infer_out.embed0)
                    elif cur_phase_idx == 1:
                        y_out = self.infer_self_attn(
                            cur_phase,
                            pre_infer_out.grid_sizes,
                            x,
                            pre_infer_out.seq_lens,
                            pre_infer_out.freqs,
                            shift_msa,
                            scale_msa,
                        )
                    elif cur_phase_idx == 2:
                        x, attn_out = self.infer_cross_attn(cur_phase, x, pre_infer_out.context, y_out, gate_msa)
                    elif cur_phase_idx == 3:
                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                        x = self.post_process(x, y, c_gate_msa, pre_infer_out)
                is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                if not is_last_phase:
                    next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
                    next_phase_idx = (phase_idx + 1) % self.phases_num
                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
                self.weights_stream_mgr.swap_phases()
            if self.clean_cuda_cache:
                del attn_out, y_out, y
                torch.cuda.empty_cache()
        if self.clean_cuda_cache:
            del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
            torch.cuda.empty_cache()
        return x

    def infer_with_phases_lazy_offload(self, weights, x, pre_infer_out):
        """Phase-granularity offload with lazy disk staging.

        Same phase pipeline as `infer_with_phases_offload`, but phases are
        fetched from the pinned-memory buffer keyed by (block_idx, phase_idx),
        and the disk prefetcher is advanced after each block.
        """
        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
        for block_idx in range(weights.blocks_num):
            self.block_idx = block_idx
            for phase_idx in range(self.weights_stream_mgr.phases_num):
                if block_idx == 0 and phase_idx == 0:
                    obj_key = (block_idx, phase_idx)
                    phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
                    phase.to_cuda()
                    self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
                with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                    # active_weights[0] holds ((block_idx, phase_idx), phase).
                    (
                        (
                            _,
                            cur_phase_idx,
                        ),
                        cur_phase,
                    ) = self.weights_stream_mgr.active_weights[0]
                    if cur_phase_idx == 0:
                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, pre_infer_out.embed0)
                    elif cur_phase_idx == 1:
                        y_out = self.infer_self_attn(
                            cur_phase,
                            pre_infer_out.grid_sizes,
                            x,
                            pre_infer_out.seq_lens,
                            pre_infer_out.freqs,
                            shift_msa,
                            scale_msa,
                        )
                    elif cur_phase_idx == 2:
                        x, attn_out = self.infer_cross_attn(cur_phase, x, pre_infer_out.context, y_out, gate_msa)
                    elif cur_phase_idx == 3:
                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
                        x = self.post_process(x, y, c_gate_msa, pre_infer_out)
                if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                    next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
                    next_phase_idx = (phase_idx + 1) % self.weights_stream_mgr.phases_num
                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
                self.weights_stream_mgr.swap_phases()
            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
            if self.clean_cuda_cache:
                del attn_out, y_out, y
                torch.cuda.empty_cache()
        if self.clean_cuda_cache:
            del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
            torch.cuda.empty_cache()
        return x
...@@ -2,14 +2,10 @@ from functools import partial ...@@ -2,14 +2,10 @@ from functools import partial
import torch import torch
from lightx2v.common.offload.manager import (
LazyWeightAsyncStreamManager,
WeightAsyncStreamManager,
)
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio, compute_freqs_audio_dist, compute_freqs_dist from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_dist
class WanTransformerInfer(BaseTransformerInfer): class WanTransformerInfer(BaseTransformerInfer):
...@@ -37,46 +33,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -37,46 +33,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else: else:
self.seq_p_group = None self.seq_p_group = None
self.infer_func = self.infer_without_offload
if self.config.get("cpu_offload", False):
# if torch.cuda.get_device_capability(0) == (9, 0):
# assert self.config["self_attn_1_type"] != "sage_attn2"
if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"]
else:
offload_ratio = 1
offload_granularity = self.config.get("offload_granularity", "block")
if offload_granularity == "block":
if not self.config.get("lazy_load", False):
self.infer_func = self._infer_with_offload
else:
self.infer_func = self._infer_with_lazy_offload
elif offload_granularity == "phase":
if not self.config.get("lazy_load", False):
self.infer_func = self._infer_with_phases_offload
else:
self.infer_func = self._infer_with_phases_lazy_offload
elif offload_granularity == "model":
self.infer_func = self._infer_without_offload
if offload_granularity != "model":
if not self.config.get("lazy_load", False):
self.weights_stream_mgr = WeightAsyncStreamManager(
blocks_num=self.blocks_num,
offload_ratio=offload_ratio,
phases_num=self.phases_num,
)
else:
self.weights_stream_mgr = LazyWeightAsyncStreamManager(
blocks_num=self.blocks_num,
offload_ratio=offload_ratio,
phases_num=self.phases_num,
num_disk_workers=self.config.get("num_disk_workers", 2),
max_memory=self.config.get("max_memory", 2),
offload_gra=offload_granularity,
)
else:
self.infer_func = self._infer_without_offload
def _calculate_q_k_len(self, q, k_lens): def _calculate_q_k_len(self, q, k_lens):
q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device) q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
...@@ -86,36 +43,20 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -86,36 +43,20 @@ class WanTransformerInfer(BaseTransformerInfer):
def compute_freqs(self, q, grid_sizes, freqs): def compute_freqs(self, q, grid_sizes, freqs):
if self.config["seq_parallel"]: if self.config["seq_parallel"]:
if "audio" in self.config.get("model_cls", ""): freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
else: else:
if "audio" in self.config.get("model_cls", ""): freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
return freqs_i return freqs_i
def infer(self, weights, pre_infer_out): def infer(self, weights, pre_infer_out):
x = self.infer_main_blocks(weights, pre_infer_out) x = self.infer_main_blocks(weights, pre_infer_out)
return self.infer_post_blocks(weights, x, pre_infer_out.embed) return self.infer_non_blocks(weights, x, pre_infer_out.embed)
def infer_main_blocks(self, weights, pre_infer_out): def infer_main_blocks(self, weights, pre_infer_out):
x = self.infer_func( x = self.infer_func(weights, pre_infer_out.x, pre_infer_out)
weights,
pre_infer_out.grid_sizes,
pre_infer_out.embed,
pre_infer_out.x,
pre_infer_out.embed0,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
pre_infer_out.context,
pre_infer_out.audio_dit_blocks,
)
return x return x
def infer_post_blocks(self, weights, x, e): def infer_non_blocks(self, weights, x, e):
if e.dim() == 2: if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
...@@ -139,214 +80,29 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -139,214 +80,29 @@ class WanTransformerInfer(BaseTransformerInfer):
torch.cuda.empty_cache() torch.cuda.empty_cache()
return x return x
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None): def infer_without_offload(self, weights, x, pre_infer_out):
for block_idx in range(self.blocks_num):
self.block_idx = block_idx
x = self.infer_block(
weights.blocks[block_idx],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
audio_dit_blocks,
)
return x
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num): for block_idx in range(self.blocks_num):
self.block_idx = block_idx self.block_idx = block_idx
if block_idx == 0: x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
self.weights_stream_mgr.active_weights[0].to_cuda()
if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_block(
self.weights_stream_mgr.active_weights[0],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
audio_dit_blocks,
)
self.weights_stream_mgr.swap_weights()
return x
def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
for block_idx in range(self.blocks_num):
if block_idx == 0:
block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
block.to_cuda()
self.weights_stream_mgr.active_weights[0] = (block_idx, block)
if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_block(
self.weights_stream_mgr.active_weights[0][1],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
audio_dit_blocks,
)
self.weights_stream_mgr.swap_weights()
if block_idx == self.blocks_num - 1:
self.weights_stream_mgr.pin_memory_buffer.pop_front()
self.weights_stream_mgr._async_prefetch_block(weights.blocks)
if self.clean_cuda_cache:
del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache()
return x return x
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None): def infer_block(self, weights, x, pre_infer_out):
for block_idx in range(weights.blocks_num):
self.block_idx = block_idx
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
phase = weights.blocks[block_idx].compute_phases[phase_idx]
phase.to_cuda()
self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
if cur_phase_idx == 0:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, embed0)
elif cur_phase_idx == 1:
y_out = self.infer_self_attn(
cur_phase,
grid_sizes,
x,
seq_lens,
freqs,
shift_msa,
scale_msa,
)
elif cur_phase_idx == 2:
x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
if not is_last_phase:
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.phases_num
self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
self.weights_stream_mgr.swap_phases()
if self.clean_cuda_cache:
del attn_out, y_out, y
torch.cuda.empty_cache()
if self.clean_cuda_cache:
del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache()
return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
for block_idx in range(weights.blocks_num):
self.block_idx = block_idx
for phase_idx in range(self.weights_stream_mgr.phases_num):
if block_idx == 0 and phase_idx == 0:
obj_key = (block_idx, phase_idx)
phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
phase.to_cuda()
self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
(
(
_,
cur_phase_idx,
),
cur_phase,
) = self.weights_stream_mgr.active_weights[0]
if cur_phase_idx == 0:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
cur_phase,
embed0,
)
elif cur_phase_idx == 1:
y_out = self.infer_self_attn(
cur_phase,
grid_sizes,
x,
seq_lens,
freqs,
shift_msa,
scale_msa,
)
elif cur_phase_idx == 2:
x, attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks)
if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.weights_stream_mgr.phases_num
self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
self.weights_stream_mgr.swap_phases()
self.weights_stream_mgr._async_prefetch_block(weights.blocks)
if self.clean_cuda_cache:
del attn_out, y_out, y
torch.cuda.empty_cache()
if self.clean_cuda_cache:
del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache()
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation( shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
weights.compute_phases[0], weights.compute_phases[0],
embed0, pre_infer_out.embed0,
) )
y_out = self.infer_self_attn( y_out = self.infer_self_attn(
weights.compute_phases[1], weights.compute_phases[1],
grid_sizes, pre_infer_out.grid_sizes,
x, x,
seq_lens, pre_infer_out.seq_lens,
freqs, pre_infer_out.freqs,
shift_msa, shift_msa,
scale_msa, scale_msa,
) )
x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa) x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, pre_infer_out.context, y_out, gate_msa)
y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa) y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, grid_sizes, audio_dit_blocks) x = self.post_process(x, y, c_gate_msa, pre_infer_out)
return x return x
def infer_modulation(self, weights, embed0): def infer_modulation(self, weights, embed0):
...@@ -531,19 +287,12 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -531,19 +287,12 @@ class WanTransformerInfer(BaseTransformerInfer):
return y return y
def post_process(self, x, y, c_gate_msa, grid_sizes, audio_dit_blocks=None): def post_process(self, x, y, c_gate_msa, pre_infer_out):
if self.sensitive_layer_dtype != self.infer_dtype: if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze() x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
else: else:
x.add_(y * c_gate_msa.squeeze()) x.add_(y * c_gate_msa.squeeze())
# Apply audio_dit if available
if audio_dit_blocks is not None and hasattr(self, "block_idx"):
for ipa_out in audio_dit_blocks:
if self.block_idx in ipa_out:
cur_modify = ipa_out[self.block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
if self.clean_cuda_cache: if self.clean_cuda_cache:
del y, c_gate_msa del y, c_gate_msa
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
...@@ -18,6 +18,9 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import ...@@ -18,6 +18,9 @@ from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import
WanTransformerInferTaylorCaching, WanTransformerInferTaylorCaching,
WanTransformerInferTeaCaching, WanTransformerInferTeaCaching,
) )
from lightx2v.models.networks.wan.infer.offload.transformer_infer import (
WanOffloadTransformerInfer,
)
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.transformer_infer import ( from lightx2v.models.networks.wan.infer.transformer_infer import (
...@@ -64,7 +67,12 @@ class WanModel: ...@@ -64,7 +67,12 @@ class WanModel:
self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme) self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
self.config.use_gguf = True self.config.use_gguf = True
else: else:
self.dit_quantized_ckpt = find_hf_model_path(config, self.model_path, "dit_quantized_ckpt", subdir=dit_quant_scheme) self.dit_quantized_ckpt = find_hf_model_path(
config,
self.model_path,
"dit_quantized_ckpt",
subdir=dit_quant_scheme,
)
quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json") quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json")
if os.path.exists(quant_config_path): if os.path.exists(quant_config_path):
with open(quant_config_path, "r") as f: with open(quant_config_path, "r") as f:
...@@ -90,7 +98,7 @@ class WanModel: ...@@ -90,7 +98,7 @@ class WanModel:
self.post_infer_class = WanPostInfer self.post_infer_class = WanPostInfer
if self.config["feature_caching"] == "NoCaching": if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer self.transformer_infer_class = WanTransformerInfer if not self.cpu_offload else WanOffloadTransformerInfer
elif self.config["feature_caching"] == "Tea": elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching self.transformer_infer_class = WanTransformerInferTeaCaching
elif self.config["feature_caching"] == "TaylorSeer": elif self.config["feature_caching"] == "TaylorSeer":
...@@ -158,7 +166,11 @@ class WanModel: ...@@ -158,7 +166,11 @@ class WanModel:
with safe_open(safetensor_path, framework="pt") as f: with safe_open(safetensor_path, framework="pt") as f:
logger.info(f"Loading weights from {safetensor_path}") logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys(): for k in f.keys():
if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]: if f.get_tensor(k).dtype in [
torch.float16,
torch.bfloat16,
torch.float,
]:
if unified_dtype or all(s not in k for s in sensitive_layer): if unified_dtype or all(s not in k for s in sensitive_layer):
weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device) weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
else: else:
...@@ -176,7 +188,11 @@ class WanModel: ...@@ -176,7 +188,11 @@ class WanModel:
safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors") safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
with safe_open(safetensor_path, framework="pt", device="cpu") as f: with safe_open(safetensor_path, framework="pt", device="cpu") as f:
for k in f.keys(): for k in f.keys():
if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]: if f.get_tensor(k).dtype in [
torch.float16,
torch.bfloat16,
torch.float,
]:
if unified_dtype or all(s not in k for s in sensitive_layer): if unified_dtype or all(s not in k for s in sensitive_layer):
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device) pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment