"doc/vscode:/vscode.git/clone" did not exist on "063c5489b349fe2d7f786c8196acc2bae5b24ce6"
Commit abeb9bc8 authored by gushiqiao, committed by GitHub

[Feat] Support vace offload and recon offload. (#245)

parent 87343386
@@ -10,21 +10,27 @@ from loguru import logger
 class WeightAsyncStreamManager(object):
     def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
-        self.active_weights = [None for _ in range(3)]
+        self.init(blocks_num, phases_num, offload_ratio)
         self.compute_stream = torch.cuda.Stream(priority=-1)
         self.cpu_load_stream = torch.cuda.Stream(priority=0)
         self.cuda_load_stream = torch.cuda.Stream(priority=0)
-        self.offload_block_num = int(offload_ratio * blocks_num)
+
+    def init(self, blocks_num, phases_num, offload_ratio):
+        if hasattr(self, "active_weights"):
+            del self.active_weights[:]
+        self.active_weights = [None for _ in range(3)]
+        self.blocks_num = blocks_num
         self.phases_num = phases_num
-        self.block_nums = blocks_num
-        self.offload_phases_num = blocks_num * phases_num * offload_ratio
+        self.offload_ratio = offload_ratio
+        self.offload_blocks_num = int(self.offload_ratio * self.blocks_num)
+        self.offload_phases_num = self.blocks_num * self.phases_num * self.offload_ratio

     def prefetch_weights(self, block_idx, blocks_weights):
         with torch.cuda.stream(self.cuda_load_stream):
             self.active_weights[2] = blocks_weights[block_idx]
             self.active_weights[2].to_cuda_async()
         with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx < self.offload_block_num:
+            if block_idx < self.offload_blocks_num:
                 if self.active_weights[1] is not None:
                     self.active_weights[1].to_cpu_async()
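Note on the refactor above: pulling the bookkeeping out of __init__ into a separate init() lets the same stream manager be re-pointed at a different block list mid-run; the VACE path later in this diff re-inits it for the VACE blocks and then again for the main blocks. A minimal sketch of that reuse pattern (the class below is an illustrative stand-in, not code from this commit):

# Illustrative stand-in for the manager's bookkeeping; no CUDA streams here.
class TinyStreamManager:
    def __init__(self, blocks_num, offload_ratio=1.0, phases_num=1):
        self.init(blocks_num, phases_num, offload_ratio)

    def init(self, blocks_num, phases_num, offload_ratio):
        # Reset per-run state so the manager can be reused for another block list.
        self.active_weights = [None, None, None]
        self.blocks_num = blocks_num
        self.phases_num = phases_num
        self.offload_ratio = offload_ratio
        self.offload_blocks_num = int(offload_ratio * blocks_num)

mgr = TinyStreamManager(blocks_num=30, offload_ratio=0.5, phases_num=3)
print(mgr.offload_blocks_num)  # 15: only the first half of the blocks get offloaded back to CPU
mgr.init(blocks_num=15, phases_num=3, offload_ratio=0.5)  # re-point at a shorter (e.g. VACE) block list
print(mgr.offload_blocks_num)  # 7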
@@ -130,7 +136,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
         if next_block_idx < 0:
             next_block_idx = 0
-        if next_block_idx == self.block_nums:
+        if next_block_idx == self.blocks_num:
             return

         if self.offload_gra == "phase":
@@ -175,7 +181,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
                 self.pin_memory_buffer.push(block_idx, block)
                 block_idx += 1
-                if block_idx == self.block_nums:
+                if block_idx == self.blocks_num:
                     break

     def prefetch_weights_from_disk(self, blocks):
@@ -217,7 +223,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
             self.active_weights[2] = (obj_key, block)
         with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx < self.offload_block_num:
+            if block_idx < self.offload_blocks_num:
                 if self.active_weights[1] is not None:
                     old_key, old_block = self.active_weights[1]
                     if self.pin_memory_buffer.exists(old_key):
......
@@ -95,6 +95,12 @@ class MMWeight(MMWeightTemplate):
         self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
         self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None

+    def _calculate_size(self):
+        if self.bias is not None:
+            return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
+        return self.weight.numel() * self.weight.element_size()
+
     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
         dtype = input_tensor.dtype
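The new _calculate_size simply sums the byte sizes of the weight and the optional bias (numel() * element_size()), which a pinned-memory buffer can use to budget against a memory limit. A quick standalone check in plain PyTorch (shapes and dtype chosen only for illustration):

import torch

weight = torch.empty(5120, 5120, dtype=torch.bfloat16)  # illustrative linear weight
bias = torch.empty(5120, dtype=torch.bfloat16)

size_bytes = weight.numel() * weight.element_size() + bias.numel() * bias.element_size()
print(size_bytes / 1024**2)  # ~50 MiB (5120*5120*2 bytes for the weight, plus 5120*2 bytes for the bias)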
......
@@ -4,8 +4,7 @@ from lightx2v.common.offload.manager import (
     LazyWeightAsyncStreamManager,
     WeightAsyncStreamManager,
 )
-
-from ..transformer_infer import WanTransformerInfer
+from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer


 class WanOffloadTransformerInfer(WanTransformerInfer):
@@ -13,20 +12,31 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         super().__init__(config)
         if self.config.get("cpu_offload", False):
             if "offload_ratio" in self.config:
-                offload_ratio = self.config["offload_ratio"]
+                self.offload_ratio = self.config["offload_ratio"]
             else:
-                offload_ratio = 1
+                self.offload_ratio = 1
             offload_granularity = self.config.get("offload_granularity", "block")
             if offload_granularity == "block":
                 if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_offload
+                    self.infer_func = self.infer_with_blocks_offload
                 else:
-                    self.infer_func = self.infer_with_lazy_offload
+                    self.infer_func = self.infer_with_blocks_lazy_offload
             elif offload_granularity == "phase":
                 if not self.config.get("lazy_load", False):
                     self.infer_func = self.infer_with_phases_offload
                 else:
                     self.infer_func = self.infer_with_phases_lazy_offload
+                self.phase_params = {
+                    "shift_msa": None,
+                    "scale_msa": None,
+                    "gate_msa": None,
+                    "c_shift_msa": None,
+                    "c_scale_msa": None,
+                    "c_gate_msa": None,
+                    "y_out": None,
+                    "attn_out": None,
+                    "y": None,
+                }
             elif offload_granularity == "model":
                 self.infer_func = self.infer_without_offload
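The dispatch above is driven entirely by the config. As a reference sketch (not code from this commit), the same routing can be written as a lookup keyed by (offload_granularity, lazy_load); the method names are the ones introduced in this diff, the config values are illustrative:

dispatch = {
    ("block", False): "infer_with_blocks_offload",
    ("block", True): "infer_with_blocks_lazy_offload",
    ("phase", False): "infer_with_phases_offload",
    ("phase", True): "infer_with_phases_lazy_offload",
    ("model", False): "infer_without_offload",
    ("model", True): "infer_without_offload",  # the real code ignores lazy_load for "model"
}

config = {"offload_granularity": "phase", "lazy_load": True}  # illustrative values
print(dispatch[(config["offload_granularity"], config.get("lazy_load", False))])
# -> infer_with_phases_lazy_offload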
@@ -34,168 +44,201 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
             if not self.config.get("lazy_load", False):
                 self.weights_stream_mgr = WeightAsyncStreamManager(
                     blocks_num=self.blocks_num,
-                    offload_ratio=offload_ratio,
+                    offload_ratio=self.offload_ratio,
                     phases_num=self.phases_num,
                 )
             else:
                 self.weights_stream_mgr = LazyWeightAsyncStreamManager(
                     blocks_num=self.blocks_num,
-                    offload_ratio=offload_ratio,
+                    offload_ratio=self.offload_ratio,
                     phases_num=self.phases_num,
                     num_disk_workers=self.config.get("num_disk_workers", 2),
                     max_memory=self.config.get("max_memory", 2),
                     offload_gra=offload_granularity,
                 )

-    def infer_with_offload(self, weights, x, pre_infer_out):
-        for block_idx in range(self.blocks_num):
+    def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
             self.block_idx = block_idx
             if block_idx == 0:
-                self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
+                self.weights_stream_mgr.active_weights[0] = blocks[0]
                 self.weights_stream_mgr.active_weights[0].to_cuda()
-            if block_idx < self.blocks_num - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
+            if block_idx < len(blocks) - 1:
+                self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
+                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
             self.weights_stream_mgr.swap_weights()
         return x

-    def infer_with_lazy_offload(self, weights, x, pre_infer_out):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
-        for block_idx in range(self.blocks_num):
+    def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
+        self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
+        for block_idx in range(len(blocks)):
             if block_idx == 0:
                 block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
                 block.to_cuda()
                 self.weights_stream_mgr.active_weights[0] = (block_idx, block)
-            if block_idx < self.blocks_num - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
+            if block_idx < len(blocks) - 1:
+                self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
+                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
             self.weights_stream_mgr.swap_weights()
-            if block_idx == self.blocks_num - 1:
+            if block_idx == len(blocks) - 1:
                 self.weights_stream_mgr.pin_memory_buffer.pop_front()
-            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
+            self.weights_stream_mgr._async_prefetch_block(blocks)

         if self.clean_cuda_cache:
-            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
+            del (
+                pre_infer_out.grid_sizes,
+                pre_infer_out.embed0,
+                pre_infer_out.seq_lens,
+                pre_infer_out.freqs,
+                pre_infer_out.context,
+            )
             torch.cuda.empty_cache()
         return x
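The block-level loops above are a double-buffering pipeline: while the compute stream runs block i, the CUDA load stream prefetches block i+1 and (for block indices below offload_blocks_num) the CPU load stream copies an already-used block back to host memory; swap_weights() then rotates the roles. A stripped-down, stream-free sketch of that control flow (placeholders only, nothing here is the real manager):

# Control-flow sketch; the to-GPU / to-CPU copies are implied by the role rotation.
def run_blocks(blocks, infer_block):
    active = {"compute": None, "evict": None, "prefetch": None}
    x = None
    for i, _ in enumerate(blocks):
        if i == 0:
            active["compute"] = blocks[0]        # first block is loaded synchronously
        if i < len(blocks) - 1:
            active["prefetch"] = blocks[i + 1]   # overlapped H2D copy of the next block
        x = infer_block(active["compute"], x)    # overlapped compute on the current block
        # swap_weights(): used block becomes the eviction candidate,
        # prefetched block becomes the next compute block.
        active["evict"], active["compute"], active["prefetch"] = active["compute"], active["prefetch"], None
    return x

print(run_blocks(["block0", "block1", "block2"], lambda block, x: block))  # block2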
-    def infer_with_phases_offload(self, weights, x, pre_infer_out):
-        for block_idx in range(weights.blocks_num):
+    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            for phase_idx in range(self.phases_num):
-                if block_idx == 0 and phase_idx == 0:
-                    phase = weights.blocks[block_idx].compute_phases[phase_idx]
-                    phase.to_cuda()
-                    self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
-
-                with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                    cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
-                    if cur_phase_idx == 0:
-                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, pre_infer_out.embed0)
-                    elif cur_phase_idx == 1:
-                        y_out = self.infer_self_attn(
-                            cur_phase,
-                            pre_infer_out.grid_sizes,
-                            x,
-                            pre_infer_out.seq_lens,
-                            pre_infer_out.freqs,
-                            shift_msa,
-                            scale_msa,
-                        )
-                    elif cur_phase_idx == 2:
-                        x, attn_out = self.infer_cross_attn(cur_phase, x, pre_infer_out.context, y_out, gate_msa)
-                    elif cur_phase_idx == 3:
-                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                        x = self.post_process(x, y, c_gate_msa, pre_infer_out)
-
-                is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
-                if not is_last_phase:
-                    next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
-                    next_phase_idx = (phase_idx + 1) % self.phases_num
-                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
-                self.weights_stream_mgr.swap_phases()
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)

             if self.clean_cuda_cache:
-                del attn_out, y_out, y
+                del (
+                    self.phase_params["attn_out"],
+                    self.phase_params["y_out"],
+                    self.phase_params["y"],
+                )
                 torch.cuda.empty_cache()

         if self.clean_cuda_cache:
-            del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
-            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
-            torch.cuda.empty_cache()
+            self.clear_offload_params(pre_infer_out)

         return x
-    def infer_with_phases_lazy_offload(self, weights, x, pre_infer_out):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
-        for block_idx in range(weights.blocks_num):
+    def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
+        self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
+        for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            for phase_idx in range(self.weights_stream_mgr.phases_num):
-                if block_idx == 0 and phase_idx == 0:
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
+            self.weights_stream_mgr._async_prefetch_block(blocks)
+            if self.clean_cuda_cache:
+                del (
+                    self.phase_params["attn_out"],
+                    self.phase_params["y_out"],
+                    self.phase_params["y"],
+                )
+                torch.cuda.empty_cache()
+
+        if self.clean_cuda_cache:
+            self.clear_offload_params(pre_infer_out)
+        return x
+
+    def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
+        for phase_idx in range(self.phases_num):
+            if block_idx == 0 and phase_idx == 0:
+                if lazy:
                     obj_key = (block_idx, phase_idx)
                     phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
                     phase.to_cuda()
                     self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
+                else:
+                    phase = blocks[block_idx].compute_phases[phase_idx]
+                    phase.to_cuda()
+                    self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)

-                with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                    (
-                        (
-                            _,
-                            cur_phase_idx,
-                        ),
-                        cur_phase,
-                    ) = self.weights_stream_mgr.active_weights[0]
-                    if cur_phase_idx == 0:
-                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, pre_infer_out.embed0)
-                    elif cur_phase_idx == 1:
-                        y_out = self.infer_self_attn(
-                            cur_phase,
-                            pre_infer_out.grid_sizes,
-                            x,
-                            pre_infer_out.seq_lens,
-                            pre_infer_out.freqs,
-                            shift_msa,
-                            scale_msa,
-                        )
-                    elif cur_phase_idx == 2:
-                        x, attn_out = self.infer_cross_attn(cur_phase, x, pre_infer_out.context, y_out, gate_msa)
-                    elif cur_phase_idx == 3:
-                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                        x = self.post_process(x, y, c_gate_msa, pre_infer_out)
-
-                if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
-                    next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
-                    next_phase_idx = (phase_idx + 1) % self.weights_stream_mgr.phases_num
-                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
-                self.weights_stream_mgr.swap_phases()
-
-            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
-            if self.clean_cuda_cache:
-                del attn_out, y_out, y
-                torch.cuda.empty_cache()
-
-        if self.clean_cuda_cache:
-            del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
-            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
-            torch.cuda.empty_cache()
+            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
+                x = self.infer_phase(self.weights_stream_mgr.active_weights[0], x, pre_infer_out)
+
+            is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
+            if not is_last_phase:
+                next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
+                next_phase_idx = (phase_idx + 1) % self.phases_num
+                self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, blocks)
+            self.weights_stream_mgr.swap_phases()
+
+        return x
+
+    def infer_phase(self, active_weight, x, pre_infer_out):
+        if not self.config.get("lazy_load"):
+            cur_phase_idx, cur_phase = active_weight
+        else:
+            (_, cur_phase_idx), cur_phase = active_weight
+
+        if cur_phase_idx == 0:
+            if hasattr(cur_phase, "before_proj"):
+                x = cur_phase.before_proj.apply(x) + pre_infer_out.x
+            (
+                self.phase_params["shift_msa"],
+                self.phase_params["scale_msa"],
+                self.phase_params["gate_msa"],
+                self.phase_params["c_shift_msa"],
+                self.phase_params["c_scale_msa"],
+                self.phase_params["c_gate_msa"],
+            ) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
+            self.phase_params["y_out"] = self.infer_self_attn(
+                cur_phase,
+                pre_infer_out.grid_sizes,
+                x,
+                pre_infer_out.seq_lens,
+                pre_infer_out.freqs,
+                self.phase_params["shift_msa"],
+                self.phase_params["scale_msa"],
+            )
+        elif cur_phase_idx == 1:
+            x, self.phase_params["attn_out"] = self.infer_cross_attn(
+                cur_phase,
+                x,
+                pre_infer_out.context,
+                self.phase_params["y_out"],
+                self.phase_params["gate_msa"],
+            )
+        elif cur_phase_idx == 2:
+            self.phase_params["y"] = self.infer_ffn(
+                cur_phase,
+                x,
+                self.phase_params["attn_out"],
+                self.phase_params["c_shift_msa"],
+                self.phase_params["c_scale_msa"],
+            )
+            x = self.post_process(
+                x,
+                self.phase_params["y"],
+                self.phase_params["c_gate_msa"],
+                pre_infer_out,
+            )
+            if hasattr(cur_phase, "after_proj"):
+                pre_infer_out.adapter_output["hints"].append(cur_phase.after_proj.apply(x))
         return x
+
+    def clear_offload_params(self, pre_infer_out):
+        del (
+            self.phase_params["shift_msa"],
+            self.phase_params["scale_msa"],
+            self.phase_params["gate_msa"],
+        )
+        del (
+            self.phase_params["c_shift_msa"],
+            self.phase_params["c_scale_msa"],
+            self.phase_params["c_gate_msa"],
+        )
+        del (
+            pre_infer_out.grid_sizes,
+            pre_infer_out.embed0,
+            pre_infer_out.seq_lens,
+            pre_infer_out.freqs,
+            pre_infer_out.context,
+        )
+        torch.cuda.empty_cache()
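Since the per-phase work now lives in infer_phases/infer_phase, intermediates that used to be local variables spanning several phases (the modulation terms, y_out, attn_out, y) are parked in self.phase_params between calls and deleted afterwards by clear_offload_params. A toy sketch of that hand-off (the phase bodies are placeholders, not the real attention/FFN math):

phase_params = {"y_out": None, "attn_out": None, "y": None}

def infer_phase(idx, x):
    if idx == 0:      # modulation + self-attention
        phase_params["y_out"] = x + 1
    elif idx == 1:    # cross-attention
        x, phase_params["attn_out"] = x * 2, phase_params["y_out"]
    elif idx == 2:    # FFN + post-process
        phase_params["y"] = phase_params["attn_out"] + x
        x = x + phase_params["y"]
    return x

x = 0
for phase_idx in range(3):  # phases_num is 3 after this commit
    x = infer_phase(phase_idx, x)
print(x)  # 1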
@@ -14,7 +14,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.task = config.task
         self.attention_type = config.get("attention_type", "flash_attn2")
         self.blocks_num = config.num_layers
-        self.phases_num = 4
+        self.phases_num = 3
         self.num_heads = config.num_heads
         self.head_dim = config.dim // config.num_heads
         self.window_size = config.get("window_size", (-1, -1))
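phases_num drops from 4 to 3 because the standalone modulation phase is gone: the weight-module changes further down fold WanModulation into WanSelfAttention, so the modulation tensor now travels with the self-attention phase. The resulting layout, assuming the indices used by infer_phase above and infer_block below:

# Phase layout after this commit (contents summarised from this diff).
compute_phases = {
    0: "modulation + self-attention (modulation, norm1, self_attn_*; before_proj on the first VACE block)",
    1: "cross-attention (norm3, cross_attn_*)",
    2: "FFN + post-process (norm2, ffn_0, ffn_2; after_proj on VACE blocks)",
}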
@@ -49,11 +49,11 @@ class WanTransformerInfer(BaseTransformerInfer):
         return freqs_i

     def infer(self, weights, pre_infer_out):
-        x = self.infer_main_blocks(weights, pre_infer_out)
+        x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)

-    def infer_main_blocks(self, weights, pre_infer_out):
-        x = self.infer_func(weights, pre_infer_out.x, pre_infer_out)
+    def infer_main_blocks(self, blocks, pre_infer_out):
+        x = self.infer_func(blocks, pre_infer_out.x, pre_infer_out)
         return x

     def infer_non_blocks(self, weights, x, e):
@@ -80,19 +80,22 @@ class WanTransformerInfer(BaseTransformerInfer):
             torch.cuda.empty_cache()
         return x

-    def infer_without_offload(self, weights, x, pre_infer_out):
-        for block_idx in range(self.blocks_num):
+    def infer_without_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
+            x = self.infer_block(blocks[block_idx], x, pre_infer_out)
         return x

-    def infer_block(self, weights, x, pre_infer_out):
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
-            weights.compute_phases[0],
+    def infer_block(self, block, x, pre_infer_out):
+        if hasattr(block.compute_phases[0], "before_proj"):
+            x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process(
+            block.compute_phases[0].modulation,
             pre_infer_out.embed0,
         )
         y_out = self.infer_self_attn(
-            weights.compute_phases[1],
+            block.compute_phases[0],
             pre_infer_out.grid_sizes,
             x,
             pre_infer_out.seq_lens,
@@ -100,18 +103,21 @@ class WanTransformerInfer(BaseTransformerInfer):
             shift_msa,
             scale_msa,
         )
-        x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, pre_infer_out.context, y_out, gate_msa)
-        y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
+        x, attn_out = self.infer_cross_attn(block.compute_phases[1], x, pre_infer_out.context, y_out, gate_msa)
+        y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
         x = self.post_process(x, y, c_gate_msa, pre_infer_out)
+        if hasattr(block.compute_phases[2], "after_proj"):
+            pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x))
         return x

-    def infer_modulation(self, weights, embed0):
+    def pre_process(self, modulation, embed0):
         if embed0.dim() == 3 and embed0.shape[2] == 1:
-            modulation = weights.modulation.tensor.unsqueeze(2)
+            modulation = modulation.tensor.unsqueeze(2)
             embed0 = (modulation + embed0).chunk(6, dim=1)
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
         else:
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (modulation.tensor + embed0).chunk(6, dim=1)

         if self.clean_cuda_cache:
             del embed0
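pre_process (formerly infer_modulation) now receives the modulation parameter directly instead of the whole phase object; the arithmetic is unchanged: modulation plus the timestep embedding, split into six chunks (shift/scale/gate for self-attention, then for the FFN). A standalone shape check, where dim=16 and the (1, 6, dim) layout are assumptions for illustration only:

import torch

dim = 16
modulation = torch.randn(1, 6, dim)  # assumed layout of the learned modulation table
embed0 = torch.randn(1, 6, dim)      # timestep embedding projected to the same shape

shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (modulation + embed0).chunk(6, dim=1)
print(shift_msa.shape)  # torch.Size([1, 1, 16]); the squeeze happens where the terms are consumed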
@@ -119,15 +125,15 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
-        if hasattr(weights, "smooth_norm1_weight"):
-            norm1_weight = (1 + scale_msa.squeeze()) * weights.smooth_norm1_weight.tensor
-            norm1_bias = shift_msa.squeeze() * weights.smooth_norm1_bias.tensor
+    def infer_self_attn(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
+        if hasattr(phase, "smooth_norm1_weight"):
+            norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
+            norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
         else:
             norm1_weight = 1 + scale_msa.squeeze()
             norm1_bias = shift_msa.squeeze()

-        norm1_out = weights.norm1.apply(x)
+        norm1_out = phase.norm1.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.sensitive_layer_dtype)
@@ -139,9 +145,9 @@ class WanTransformerInfer(BaseTransformerInfer):
         s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim

-        q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d)
-        k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
-        v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
+        q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d)
+        k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
+        v = phase.self_attn_v.apply(norm1_out).view(s, n, d)

         freqs_i = self.compute_freqs(q, grid_sizes, freqs)
@@ -156,18 +162,18 @@ class WanTransformerInfer(BaseTransformerInfer):
             torch.cuda.empty_cache()

         if self.config["seq_parallel"]:
-            attn_out = weights.self_attn_1_parallel.apply(
+            attn_out = phase.self_attn_1_parallel.apply(
                 q=q,
                 k=k,
                 v=v,
                 img_qkv_len=q.shape[0],
                 cu_seqlens_qkv=cu_seqlens_q,
-                attention_module=weights.self_attn_1,
+                attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
                 model_cls=self.config["model_cls"],
             )
         else:
-            attn_out = weights.self_attn_1.apply(
+            attn_out = phase.self_attn_1.apply(
                 q=q,
                 k=k,
                 v=v,
@@ -179,7 +185,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 mask_map=self.mask_map,
             )

-        y = weights.self_attn_o.apply(attn_out)
+        y = phase.self_attn_o.apply(attn_out)
         if self.clean_cuda_cache:
             del q, k, v, attn_out
@@ -187,13 +193,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         return y

-    def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
+    def infer_cross_attn(self, phase, x, context, y_out, gate_msa):
         if self.sensitive_layer_dtype != self.infer_dtype:
             x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze()
         else:
             x.add_(y_out * gate_msa.squeeze())

-        norm3_out = weights.norm3.apply(x)
+        norm3_out = phase.norm3.apply(x)
         if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]
@@ -207,14 +213,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         n, d = self.num_heads, self.head_dim

-        q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
-        k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
-        v = weights.cross_attn_v.apply(context).view(-1, n, d)
+        q = phase.cross_attn_norm_q.apply(phase.cross_attn_q.apply(norm3_out)).view(-1, n, d)
+        k = phase.cross_attn_norm_k.apply(phase.cross_attn_k.apply(context)).view(-1, n, d)
+        v = phase.cross_attn_v.apply(context).view(-1, n, d)

         cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
             q,
             k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device),
         )
-        attn_out = weights.cross_attn_1.apply(
+        attn_out = phase.cross_attn_1.apply(
             q=q,
             k=k,
             v=v,
@@ -226,14 +232,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         )

         if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
-            k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
-            v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
+            k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
+            v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
             cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
                 q,
                 k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
             )
-            img_attn_out = weights.cross_attn_2.apply(
+            img_attn_out = phase.cross_attn_2.apply(
                 q=q,
                 k=k_img,
                 v=v_img,
@@ -249,42 +255,42 @@ class WanTransformerInfer(BaseTransformerInfer):
                 del k_img, v_img, img_attn_out
                 torch.cuda.empty_cache()

-        attn_out = weights.cross_attn_o.apply(attn_out)
+        attn_out = phase.cross_attn_o.apply(attn_out)
         if self.clean_cuda_cache:
             del q, k, v, norm3_out, context, context_img
             torch.cuda.empty_cache()

         return x, attn_out

-    def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
+    def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa):
         x.add_(attn_out)
         if self.clean_cuda_cache:
             del attn_out
             torch.cuda.empty_cache()

-        if hasattr(weights, "smooth_norm2_weight"):
-            norm2_weight = (1 + c_scale_msa.squeeze()) * weights.smooth_norm2_weight.tensor
-            norm2_bias = c_shift_msa.squeeze() * weights.smooth_norm2_bias.tensor
+        if hasattr(phase, "smooth_norm2_weight"):
+            norm2_weight = (1 + c_scale_msa.squeeze()) * phase.smooth_norm2_weight.tensor
+            norm2_bias = c_shift_msa.squeeze() * phase.smooth_norm2_bias.tensor
         else:
             norm2_weight = 1 + c_scale_msa.squeeze()
             norm2_bias = c_shift_msa.squeeze()

-        norm2_out = weights.norm2.apply(x)
+        norm2_out = phase.norm2.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.sensitive_layer_dtype)
         norm2_out.mul_(norm2_weight).add_(norm2_bias)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.infer_dtype)

-        y = weights.ffn_0.apply(norm2_out)
+        y = phase.ffn_0.apply(norm2_out)
         if self.clean_cuda_cache:
             del norm2_out, x, norm2_weight, norm2_bias
             torch.cuda.empty_cache()

         y = torch.nn.functional.gelu(y, approximate="tanh")
         if self.clean_cuda_cache:
             torch.cuda.empty_cache()

-        y = weights.ffn_2.apply(y)
+        y = phase.ffn_2.apply(y)
         return y
......
@@ -5,41 +5,33 @@ from lightx2v.utils.envs import *
 class WanVaceTransformerInfer(WanOffloadTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
-        self.vace_block_nums = len(self.config.vace_layers)
+        self.vace_blocks_num = len(self.config.vace_layers)
         self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}

     def infer(self, weights, pre_infer_out):
-        pre_infer_out.adapter_output["hints"] = self.infer_vace(weights, pre_infer_out)
-        x = self.infer_main_blocks(weights, pre_infer_out)
+        pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
+        self.infer_vace_blocks(weights.vace_blocks, pre_infer_out)
+        x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)

-    def infer_vace(self, weights, pre_infer_out):
-        c = weights.vace_patch_embedding.apply(pre_infer_out.vace_context.unsqueeze(0).to(self.sensitive_layer_dtype))
+    def vace_pre_process(self, patch_embedding, vace_context):
+        c = patch_embedding.apply(vace_context.unsqueeze(0).to(self.sensitive_layer_dtype))
         c = c.flatten(2).transpose(1, 2).contiguous().squeeze(0)
+        return c
+
+    def infer_vace_blocks(self, vace_blocks, pre_infer_out):
+        pre_infer_out.adapter_output["hints"] = []
         self.infer_state = "vace"
-        hints = []
-
-        for i in range(self.vace_block_nums):
-            c, c_skip = self.infer_vace_block(weights.vace_blocks[i], c, pre_infer_out.x, pre_infer_out)
-            hints.append(c_skip)
-
+        if hasattr(self, "weights_stream_mgr"):
+            self.weights_stream_mgr.init(self.vace_blocks_num, self.phases_num, self.offload_ratio)
+        self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
         self.infer_state = "base"
-        return hints
-
-    def infer_vace_block(self, weights, c, x, pre_infer_out):
-        if hasattr(weights, "before_proj"):
-            c = weights.before_proj.apply(c) + x
-        c = self.infer_block(weights, c, pre_infer_out)
-        c_skip = weights.after_proj.apply(c)
-        return c, c_skip
+        if hasattr(self, "weights_stream_mgr"):
+            self.weights_stream_mgr.init(self.blocks_num, self.phases_num, self.offload_ratio)

     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
         if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
             hint_idx = self.vace_blocks_mapping[self.block_idx]
             x = x + pre_infer_out.adapter_output["hints"][hint_idx] * pre_infer_out.adapter_output.get("context_scale", 1.0)
         return x
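Flow after this change: the VACE blocks run first through the same offload-aware infer_func (with the stream manager re-initialised for vace_blocks_num), each VACE block's after_proj output is appended to adapter_output["hints"], and during the main pass post_process adds the matching hint back for every base block listed in vace_layers. A toy check of the index mapping, using the layer list from the comment removed further down in this diff:

vace_layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(vace_layers)}

print(vace_blocks_mapping[4])                                # 2 -> base block 4 consumes hints[2]
print(10 in vace_blocks_mapping, 11 in vace_blocks_mapping)  # True False -> odd-indexed blocks get no hint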
@@ -27,7 +27,7 @@ class WanTransformerWeights(WeightModule):
         self.add_module("blocks", self.blocks)

-        # post blocks weights
+        # non blocks weights
         self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
@@ -67,15 +67,6 @@ class WanTransformerAttentionBlock(WeightModule):
         self.compute_phases = WeightModuleList(
             [
-                WanModulation(
-                    block_index,
-                    block_prefix,
-                    task,
-                    mm_type,
-                    config,
-                    self.lazy_load,
-                    self.lazy_load_file,
-                ),
                 WanSelfAttention(
                     block_index,
                     block_prefix,
@@ -109,7 +100,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("compute_phases", self.compute_phases)


-class WanModulation(WeightModule):
+class WanSelfAttention(WeightModule):
     def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
@@ -131,20 +122,6 @@ class WanModulation(WeightModule):
                 ),
             )

-
-class WanSelfAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
-        super().__init__()
-        self.block_index = block_index
-        self.mm_type = mm_type
-        self.task = task
-        self.config = config
-        self.quant_method = config.get("quant_method", None)
-        self.sparge = config.get("sparge", False)
-        self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
-
         self.add_module(
             "norm1",
             LN_WEIGHT_REGISTER["Default"](),
......
@@ -9,8 +9,6 @@ from lightx2v.utils.registry_factory import (
 )

-# "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
-# {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}

 class WanVaceTransformerWeights(WanTransformerWeights):
     def __init__(self, config):
         super().__init__(config)
@@ -44,7 +42,7 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
     def __init__(self, base_block_idx, block_index, task, mm_type, config, block_prefix):
         super().__init__(block_index, task, mm_type, config, block_prefix)
         if base_block_idx == 0:
-            self.add_module(
+            self.compute_phases[0].add_module(
                 "before_proj",
                 MM_WEIGHT_REGISTER[self.mm_type](
                     f"{block_prefix}.{self.block_index}.before_proj.weight",
@@ -53,7 +51,8 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
                     self.lazy_load_file,
                 ),
             )
-        self.add_module(
+
+        self.compute_phases[-1].add_module(
             "after_proj",
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.after_proj.weight",
......