"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "f003cd354868a4c442adb710a81678fcaa47b0eb"
Commit abeb9bc8 authored by gushiqiao, committed by GitHub

[Feat] Support vace offload and recon offload. (#245)

parent 87343386
@@ -10,21 +10,27 @@ from loguru import logger
class WeightAsyncStreamManager(object):
def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
self.active_weights = [None for _ in range(3)]
self.init(blocks_num, phases_num, offload_ratio)
self.compute_stream = torch.cuda.Stream(priority=-1)
self.cpu_load_stream = torch.cuda.Stream(priority=0)
self.cuda_load_stream = torch.cuda.Stream(priority=0)
self.offload_block_num = int(offload_ratio * blocks_num)
def init(self, blocks_num, phases_num, offload_ratio):
if hasattr(self, "active_weights"):
del self.active_weights[:]
self.active_weights = [None for _ in range(3)]
self.blocks_num = blocks_num
self.phases_num = phases_num
self.block_nums = blocks_num
self.offload_phases_num = blocks_num * phases_num * offload_ratio
self.offload_ratio = offload_ratio
self.offload_blocks_num = int(self.offload_ratio * self.blocks_num)
self.offload_phases_num = self.blocks_num * self.phases_num * self.offload_ratio
def prefetch_weights(self, block_idx, blocks_weights):
with torch.cuda.stream(self.cuda_load_stream):
self.active_weights[2] = blocks_weights[block_idx]
self.active_weights[2].to_cuda_async()
with torch.cuda.stream(self.cpu_load_stream):
if block_idx < self.offload_block_num:
if block_idx < self.offload_blocks_num:
if self.active_weights[1] is not None:
self.active_weights[1].to_cpu_async()
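A minimal sketch of the double-buffering pattern the manager above implements: the next block's weights are copied host-to-device on a dedicated load stream while the current block computes, and, for the first offload_ratio fraction of blocks, the previous block is copied back to host on a second stream. The rotation step is assumed here, since swap_weights is not part of this hunk.

import torch

# Simplified, hypothetical sketch of the prefetch/offload pattern above.
# Block objects are assumed to expose to_cuda_async()/to_cpu_async(), as they
# do elsewhere in this diff; swap() is an assumed buffer rotation.
class TinyStreamManager:
    def __init__(self, blocks_num, offload_ratio=1.0):
        self.compute_stream = torch.cuda.Stream(priority=-1)    # higher priority
        self.cuda_load_stream = torch.cuda.Stream(priority=0)   # host-to-device copies
        self.cpu_load_stream = torch.cuda.Stream(priority=0)    # device-to-host offload
        self.offload_blocks_num = int(offload_ratio * blocks_num)
        self.active = [None, None, None]                        # current, previous, prefetched

    def prefetch(self, block_idx, blocks):
        with torch.cuda.stream(self.cuda_load_stream):
            self.active[2] = blocks[block_idx]
            self.active[2].to_cuda_async()                      # async copy of the next block
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_blocks_num and self.active[1] is not None:
                self.active[1].to_cpu_async()                   # offload the previous block

    def swap(self):
        # Make the compute stream wait for both copy streams, then rotate buffers.
        self.compute_stream.wait_stream(self.cuda_load_stream)
        self.compute_stream.wait_stream(self.cpu_load_stream)
        self.active[1], self.active[0] = self.active[0], self.active[2]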
@@ -130,7 +136,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
if next_block_idx < 0:
next_block_idx = 0
if next_block_idx == self.block_nums:
if next_block_idx == self.blocks_num:
return
if self.offload_gra == "phase":
@@ -175,7 +181,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
self.pin_memory_buffer.push(block_idx, block)
block_idx += 1
if block_idx == self.block_nums:
if block_idx == self.blocks_num:
break
def prefetch_weights_from_disk(self, blocks):
@@ -217,7 +223,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
self.active_weights[2] = (obj_key, block)
with torch.cuda.stream(self.cpu_load_stream):
if block_idx < self.offload_block_num:
if block_idx < self.offload_blocks_num:
if self.active_weights[1] is not None:
old_key, old_block = self.active_weights[1]
if self.pin_memory_buffer.exists(old_key):
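The lazy-load path above stages blocks fetched from disk in a pin_memory_buffer before the host-to-device copy; page-locked host memory is what lets those copies run asynchronously with respect to the compute stream. A rough, self-contained illustration of such a buffer (the real pin_memory_buffer implementation is not shown in this diff, so its API here is assumed from the call sites above):

import torch
from collections import OrderedDict

# Rough illustration only; keys mirror the call sites above (block_idx or
# (block_idx, phase_idx)); everything else is an assumption.
class PinnedStagingBuffer:
    def __init__(self, max_items=8):
        self.max_items = max_items
        self.store = OrderedDict()

    def push(self, key, tensor):
        # Copy into page-locked host memory so a later non_blocking copy to the GPU
        # can overlap with compute on another stream.
        staged = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True)
        staged.copy_(tensor)
        self.store[key] = staged
        if len(self.store) > self.max_items:
            self.store.popitem(last=False)      # evict the oldest entry

    def exists(self, key):
        return key in self.store

    def get(self, key):
        return self.store[key]

    def pop_front(self):
        return self.store.popitem(last=False)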
@@ -95,6 +95,12 @@ class MMWeight(MMWeightTemplate):
self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None
def _calculate_size(self):
if self.bias is not None:
return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
return self.weight.numel() * self.weight.element_size()
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
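The new _calculate_size reports a weight's footprint in bytes as numel() * element_size(), plus the same term for the bias when present. A quick check of that accounting with a hypothetical bf16 5120 x 5120 projection:

import torch

# Sanity check of the numel() * element_size() accounting above; the 5120
# sizes are illustrative, not taken from the model config.
w = torch.empty(5120, 5120, dtype=torch.bfloat16)
b = torch.empty(5120, dtype=torch.bfloat16)
size_bytes = w.numel() * w.element_size() + b.numel() * b.element_size()
print(size_bytes / 2**20)   # ~50.01 MiB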
@@ -4,8 +4,7 @@ from lightx2v.common.offload.manager import (
LazyWeightAsyncStreamManager,
WeightAsyncStreamManager,
)
from ..transformer_infer import WanTransformerInfer
from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
class WanOffloadTransformerInfer(WanTransformerInfer):
@@ -13,20 +12,31 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
super().__init__(config)
if self.config.get("cpu_offload", False):
if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"]
self.offload_ratio = self.config["offload_ratio"]
else:
offload_ratio = 1
self.offload_ratio = 1
offload_granularity = self.config.get("offload_granularity", "block")
if offload_granularity == "block":
if not self.config.get("lazy_load", False):
self.infer_func = self.infer_with_offload
self.infer_func = self.infer_with_blocks_offload
else:
self.infer_func = self.infer_with_lazy_offload
self.infer_func = self.infer_with_blocks_lazy_offload
elif offload_granularity == "phase":
if not self.config.get("lazy_load", False):
self.infer_func = self.infer_with_phases_offload
else:
self.infer_func = self.infer_with_phases_lazy_offload
self.phase_params = {
"shift_msa": None,
"scale_msa": None,
"gate_msa": None,
"c_shift_msa": None,
"c_scale_msa": None,
"c_gate_msa": None,
"y_out": None,
"attn_out": None,
"y": None,
}
elif offload_granularity == "model":
self.infer_func = self.infer_without_offload
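The constructor above now picks one of five inference paths from offload_granularity and lazy_load. A condensed, hypothetical restatement of that dispatch (the real code keeps the if/else chain shown above):

# Hypothetical helper equivalent to the infer_func dispatch in __init__ above.
def pick_infer_func(infer_obj, offload_granularity, lazy_load):
    if offload_granularity == "model":
        return infer_obj.infer_without_offload
    table = {
        ("block", False): infer_obj.infer_with_blocks_offload,
        ("block", True): infer_obj.infer_with_blocks_lazy_offload,
        ("phase", False): infer_obj.infer_with_phases_offload,
        ("phase", True): infer_obj.infer_with_phases_lazy_offload,
    }
    return table[(offload_granularity, lazy_load)]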
@@ -34,168 +44,201 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
if not self.config.get("lazy_load", False):
self.weights_stream_mgr = WeightAsyncStreamManager(
blocks_num=self.blocks_num,
offload_ratio=offload_ratio,
offload_ratio=self.offload_ratio,
phases_num=self.phases_num,
)
else:
self.weights_stream_mgr = LazyWeightAsyncStreamManager(
blocks_num=self.blocks_num,
offload_ratio=offload_ratio,
offload_ratio=self.offload_ratio,
phases_num=self.phases_num,
num_disk_workers=self.config.get("num_disk_workers", 2),
max_memory=self.config.get("max_memory", 2),
offload_gra=offload_granularity,
)
def infer_with_offload(self, weights, x, pre_infer_out):
for block_idx in range(self.blocks_num):
def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
self.weights_stream_mgr.active_weights[0] = blocks[0]
self.weights_stream_mgr.active_weights[0].to_cuda()
if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
if block_idx < len(blocks) - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
x = self.infer_block(blocks[block_idx], x, pre_infer_out)
self.weights_stream_mgr.swap_weights()
return x
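The block-offload loop above follows a fixed per-iteration schedule: block 0 is moved in synchronously, then for each block the next block's prefetch is issued on the load stream, the current block runs on the compute stream, and the buffers are swapped. A stripped-down skeleton of that schedule (assuming a manager with the interface used above):

import torch

# Skeleton of the schedule in infer_with_blocks_offload; `mgr` is assumed to
# expose active_weights, prefetch_weights, swap_weights and compute_stream.
def run_blocks(mgr, blocks, x, step):
    for i in range(len(blocks)):
        if i == 0:
            mgr.active_weights[0] = blocks[0]
            mgr.active_weights[0].to_cuda()        # first block: blocking copy
        if i < len(blocks) - 1:
            mgr.prefetch_weights(i + 1, blocks)    # overlap the next copy with compute
        with torch.cuda.stream(mgr.compute_stream):
            x = step(blocks[i], x)                 # compute with the resident block
        mgr.swap_weights()                         # rotate current/previous/prefetched
    return x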
def infer_with_lazy_offload(self, weights, x, pre_infer_out):
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
for block_idx in range(self.blocks_num):
for block_idx in range(len(blocks)):
if block_idx == 0:
block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
block.to_cuda()
self.weights_stream_mgr.active_weights[0] = (block_idx, block)
if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
if block_idx < len(blocks) - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
x = self.infer_block(blocks[block_idx], x, pre_infer_out)
self.weights_stream_mgr.swap_weights()
if block_idx == self.blocks_num - 1:
if block_idx == len(blocks) - 1:
self.weights_stream_mgr.pin_memory_buffer.pop_front()
self.weights_stream_mgr._async_prefetch_block(weights.blocks)
self.weights_stream_mgr._async_prefetch_block(blocks)
if self.clean_cuda_cache:
del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
del (
pre_infer_out.grid_sizes,
pre_infer_out.embed0,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
pre_infer_out.context,
)
torch.cuda.empty_cache()
return x
def infer_with_phases_offload(self, weights, x, pre_infer_out):
for block_idx in range(weights.blocks_num):
def infer_with_phases_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
phase = weights.blocks[block_idx].compute_phases[phase_idx]
phase.to_cuda()
self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
if cur_phase_idx == 0:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, pre_infer_out.embed0)
elif cur_phase_idx == 1:
y_out = self.infer_self_attn(
cur_phase,
pre_infer_out.grid_sizes,
x,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
shift_msa,
scale_msa,
)
elif cur_phase_idx == 2:
x, attn_out = self.infer_cross_attn(cur_phase, x, pre_infer_out.context, y_out, gate_msa)
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, pre_infer_out)
is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
if not is_last_phase:
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.phases_num
self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
self.weights_stream_mgr.swap_phases()
x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
if self.clean_cuda_cache:
del attn_out, y_out, y
del (
self.phase_params["attn_out"],
self.phase_params["y_out"],
self.phase_params["y"],
)
torch.cuda.empty_cache()
if self.clean_cuda_cache:
del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
torch.cuda.empty_cache()
self.clear_offload_params(pre_infer_out)
return x
def infer_with_phases_lazy_offload(self, weights, x, pre_infer_out):
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
for block_idx in range(weights.blocks_num):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
for phase_idx in range(self.weights_stream_mgr.phases_num):
if block_idx == 0 and phase_idx == 0:
x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
self.weights_stream_mgr._async_prefetch_block(blocks)
if self.clean_cuda_cache:
del (
self.phase_params["attn_out"],
self.phase_params["y_out"],
self.phase_params["y"],
)
torch.cuda.empty_cache()
if self.clean_cuda_cache:
self.clear_offload_params(pre_infer_out)
return x
def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
if lazy:
obj_key = (block_idx, phase_idx)
phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
phase.to_cuda()
self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
else:
phase = blocks[block_idx].compute_phases[phase_idx]
phase.to_cuda()
self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
(
(
_,
cur_phase_idx,
),
cur_phase,
) = self.weights_stream_mgr.active_weights[0]
if cur_phase_idx == 0:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, pre_infer_out.embed0)
elif cur_phase_idx == 1:
y_out = self.infer_self_attn(
cur_phase,
pre_infer_out.grid_sizes,
x,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
shift_msa,
scale_msa,
)
elif cur_phase_idx == 2:
x, attn_out = self.infer_cross_attn(cur_phase, x, pre_infer_out.context, y_out, gate_msa)
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, pre_infer_out)
if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.weights_stream_mgr.phases_num
self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
self.weights_stream_mgr.swap_phases()
self.weights_stream_mgr._async_prefetch_block(weights.blocks)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_phase(self.weights_stream_mgr.active_weights[0], x, pre_infer_out)
if self.clean_cuda_cache:
del attn_out, y_out, y
torch.cuda.empty_cache()
is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
if not is_last_phase:
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.phases_num
self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, blocks)
if self.clean_cuda_cache:
del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
torch.cuda.empty_cache()
self.weights_stream_mgr.swap_phases()
return x
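infer_phases walks the (block, phase) grid in order and always prefetches the unit one step ahead: the next phase of the same block, or phase 0 of the next block once the current block's phases are exhausted. The index arithmetic in isolation:

# Index arithmetic used above to choose the next (block, phase) to prefetch.
def next_unit(block_idx, phase_idx, blocks_num, phases_num=3):
    if block_idx == blocks_num - 1 and phase_idx == phases_num - 1:
        return None                                   # nothing left to prefetch
    next_block_idx = block_idx + 1 if phase_idx == phases_num - 1 else block_idx
    next_phase_idx = (phase_idx + 1) % phases_num
    return next_block_idx, next_phase_idx

# e.g. with phases_num=3: (0, 1) -> (0, 2), (0, 2) -> (1, 0)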
def infer_phase(self, active_weight, x, pre_infer_out):
if not self.config.get("lazy_load"):
cur_phase_idx, cur_phase = active_weight
else:
(_, cur_phase_idx), cur_phase = active_weight
if cur_phase_idx == 0:
if hasattr(cur_phase, "before_proj"):
x = cur_phase.before_proj.apply(x) + pre_infer_out.x
(
self.phase_params["shift_msa"],
self.phase_params["scale_msa"],
self.phase_params["gate_msa"],
self.phase_params["c_shift_msa"],
self.phase_params["c_scale_msa"],
self.phase_params["c_gate_msa"],
) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
self.phase_params["y_out"] = self.infer_self_attn(
cur_phase,
pre_infer_out.grid_sizes,
x,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
self.phase_params["shift_msa"],
self.phase_params["scale_msa"],
)
elif cur_phase_idx == 1:
x, self.phase_params["attn_out"] = self.infer_cross_attn(
cur_phase,
x,
pre_infer_out.context,
self.phase_params["y_out"],
self.phase_params["gate_msa"],
)
elif cur_phase_idx == 2:
self.phase_params["y"] = self.infer_ffn(
cur_phase,
x,
self.phase_params["attn_out"],
self.phase_params["c_shift_msa"],
self.phase_params["c_scale_msa"],
)
x = self.post_process(
x,
self.phase_params["y"],
self.phase_params["c_gate_msa"],
pre_infer_out,
)
if hasattr(cur_phase, "after_proj"):
pre_infer_out.adapter_output["hints"].append(cur_phase.after_proj.apply(x))
return x
def clear_offload_params(self, pre_infer_out):
del (
self.phase_params["shift_msa"],
self.phase_params["scale_msa"],
self.phase_params["gate_msa"],
)
del (
self.phase_params["c_shift_msa"],
self.phase_params["c_scale_msa"],
self.phase_params["c_gate_msa"],
)
del (
pre_infer_out.grid_sizes,
pre_infer_out.embed0,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
pre_infer_out.context,
)
torch.cuda.empty_cache()
@@ -14,7 +14,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.task = config.task
self.attention_type = config.get("attention_type", "flash_attn2")
self.blocks_num = config.num_layers
self.phases_num = 4
self.phases_num = 3
self.num_heads = config.num_heads
self.head_dim = config.dim // config.num_heads
self.window_size = config.get("window_size", (-1, -1))
@@ -49,11 +49,11 @@ class WanTransformerInfer(BaseTransformerInfer):
return freqs_i
def infer(self, weights, pre_infer_out):
x = self.infer_main_blocks(weights, pre_infer_out)
x = self.infer_main_blocks(weights.blocks, pre_infer_out)
return self.infer_non_blocks(weights, x, pre_infer_out.embed)
def infer_main_blocks(self, weights, pre_infer_out):
x = self.infer_func(weights, pre_infer_out.x, pre_infer_out)
def infer_main_blocks(self, blocks, pre_infer_out):
x = self.infer_func(blocks, pre_infer_out.x, pre_infer_out)
return x
def infer_non_blocks(self, weights, x, e):
@@ -80,19 +80,22 @@ class WanTransformerInfer(BaseTransformerInfer):
torch.cuda.empty_cache()
return x
def infer_without_offload(self, weights, x, pre_infer_out):
for block_idx in range(self.blocks_num):
def infer_without_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
x = self.infer_block(blocks[block_idx], x, pre_infer_out)
return x
def infer_block(self, weights, x, pre_infer_out):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
weights.compute_phases[0],
def infer_block(self, block, x, pre_infer_out):
if hasattr(block.compute_phases[0], "before_proj"):
x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process(
block.compute_phases[0].modulation,
pre_infer_out.embed0,
)
y_out = self.infer_self_attn(
weights.compute_phases[1],
block.compute_phases[0],
pre_infer_out.grid_sizes,
x,
pre_infer_out.seq_lens,
@@ -100,18 +103,21 @@ class WanTransformerInfer(BaseTransformerInfer):
shift_msa,
scale_msa,
)
x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, pre_infer_out.context, y_out, gate_msa)
y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
x, attn_out = self.infer_cross_attn(block.compute_phases[1], x, pre_infer_out.context, y_out, gate_msa)
y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, pre_infer_out)
if hasattr(block.compute_phases[2], "after_proj"):
pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x))
return x
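With this change a block carries three compute phases instead of four: the old standalone modulation phase is folded into the self-attention phase via pre_process, which is why phases_num drops from 4 to 3 further down in this diff. A hypothetical summary of the new per-phase layout (VACE-specific projections in parentheses):

# Hypothetical summary of compute_phases after this commit; not code from the repo.
PHASE_LAYOUT = {
    0: ["(before_proj, first VACE block only)", "modulation via pre_process", "self-attention"],
    1: ["cross-attention"],
    2: ["FFN", "post_process", "(after_proj, VACE blocks only)"],
}
assert len(PHASE_LAYOUT) == 3   # matches self.phases_num = 3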
def infer_modulation(self, weights, embed0):
def pre_process(self, modulation, embed0):
if embed0.dim() == 3 and embed0.shape[2] == 1:
modulation = weights.modulation.tensor.unsqueeze(2)
modulation = modulation.tensor.unsqueeze(2)
embed0 = (modulation + embed0).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
else:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (modulation.tensor + embed0).chunk(6, dim=1)
if self.clean_cuda_cache:
del embed0
@@ -119,15 +125,15 @@ class WanTransformerInfer(BaseTransformerInfer):
return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
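pre_process adds the block's learned modulation table to the timestep embedding and splits the sum into six vectors: shift/scale/gate for the self-attention branch and c_shift/c_scale/c_gate for the FFN branch. A shape-level sketch of that split with hypothetical sizes:

import torch

# Shape-level sketch of the chunk(6, dim=1) split above; 5120 is illustrative.
dim = 5120
modulation = torch.zeros(1, 6, dim)     # per-block learned table
embed0 = torch.zeros(1, 6, dim)         # time-conditioned embedding
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (modulation + embed0).chunk(6, dim=1)
print(shift_msa.shape)                  # torch.Size([1, 1, 5120])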
def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
if hasattr(weights, "smooth_norm1_weight"):
norm1_weight = (1 + scale_msa.squeeze()) * weights.smooth_norm1_weight.tensor
norm1_bias = shift_msa.squeeze() * weights.smooth_norm1_bias.tensor
def infer_self_attn(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
if hasattr(phase, "smooth_norm1_weight"):
norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
else:
norm1_weight = 1 + scale_msa.squeeze()
norm1_bias = shift_msa.squeeze()
norm1_out = weights.norm1.apply(x)
norm1_out = phase.norm1.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
norm1_out = norm1_out.to(self.sensitive_layer_dtype)
@@ -139,9 +145,9 @@ class WanTransformerInfer(BaseTransformerInfer):
s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d)
k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d)
k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
v = phase.self_attn_v.apply(norm1_out).view(s, n, d)
freqs_i = self.compute_freqs(q, grid_sizes, freqs)
@@ -156,18 +162,18 @@ class WanTransformerInfer(BaseTransformerInfer):
torch.cuda.empty_cache()
if self.config["seq_parallel"]:
attn_out = weights.self_attn_1_parallel.apply(
attn_out = phase.self_attn_1_parallel.apply(
q=q,
k=k,
v=v,
img_qkv_len=q.shape[0],
cu_seqlens_qkv=cu_seqlens_q,
attention_module=weights.self_attn_1,
attention_module=phase.self_attn_1,
seq_p_group=self.seq_p_group,
model_cls=self.config["model_cls"],
)
else:
attn_out = weights.self_attn_1.apply(
attn_out = phase.self_attn_1.apply(
q=q,
k=k,
v=v,
@@ -179,7 +185,7 @@ class WanTransformerInfer(BaseTransformerInfer):
mask_map=self.mask_map,
)
y = weights.self_attn_o.apply(attn_out)
y = phase.self_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, attn_out
@@ -187,13 +193,13 @@ class WanTransformerInfer(BaseTransformerInfer):
return y
def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
def infer_cross_attn(self, phase, x, context, y_out, gate_msa):
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze()
else:
x.add_(y_out * gate_msa.squeeze())
norm3_out = weights.norm3.apply(x)
norm3_out = phase.norm3.apply(x)
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
context_img = context[:257]
context = context[257:]
@@ -207,14 +213,14 @@ class WanTransformerInfer(BaseTransformerInfer):
n, d = self.num_heads, self.head_dim
q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
v = weights.cross_attn_v.apply(context).view(-1, n, d)
q = phase.cross_attn_norm_q.apply(phase.cross_attn_q.apply(norm3_out)).view(-1, n, d)
k = phase.cross_attn_norm_k.apply(phase.cross_attn_k.apply(context)).view(-1, n, d)
v = phase.cross_attn_v.apply(context).view(-1, n, d)
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
q,
k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device),
)
attn_out = weights.cross_attn_1.apply(
attn_out = phase.cross_attn_1.apply(
q=q,
k=k,
v=v,
@@ -226,14 +232,14 @@ class WanTransformerInfer(BaseTransformerInfer):
)
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
q,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
)
img_attn_out = weights.cross_attn_2.apply(
img_attn_out = phase.cross_attn_2.apply(
q=q,
k=k_img,
v=v_img,
@@ -249,42 +255,42 @@ class WanTransformerInfer(BaseTransformerInfer):
del k_img, v_img, img_attn_out
torch.cuda.empty_cache()
attn_out = weights.cross_attn_o.apply(attn_out)
attn_out = phase.cross_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, norm3_out, context, context_img
torch.cuda.empty_cache()
return x, attn_out
def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa):
x.add_(attn_out)
if self.clean_cuda_cache:
del attn_out
torch.cuda.empty_cache()
if hasattr(weights, "smooth_norm2_weight"):
norm2_weight = (1 + c_scale_msa.squeeze()) * weights.smooth_norm2_weight.tensor
norm2_bias = c_shift_msa.squeeze() * weights.smooth_norm2_bias.tensor
if hasattr(phase, "smooth_norm2_weight"):
norm2_weight = (1 + c_scale_msa.squeeze()) * phase.smooth_norm2_weight.tensor
norm2_bias = c_shift_msa.squeeze() * phase.smooth_norm2_bias.tensor
else:
norm2_weight = 1 + c_scale_msa.squeeze()
norm2_bias = c_shift_msa.squeeze()
norm2_out = weights.norm2.apply(x)
norm2_out = phase.norm2.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
norm2_out = norm2_out.to(self.sensitive_layer_dtype)
norm2_out.mul_(norm2_weight).add_(norm2_bias)
if self.sensitive_layer_dtype != self.infer_dtype:
norm2_out = norm2_out.to(self.infer_dtype)
y = weights.ffn_0.apply(norm2_out)
y = phase.ffn_0.apply(norm2_out)
if self.clean_cuda_cache:
del norm2_out, x, norm2_weight, norm2_bias
torch.cuda.empty_cache()
y = torch.nn.functional.gelu(y, approximate="tanh")
if self.clean_cuda_cache:
torch.cuda.empty_cache()
y = weights.ffn_2.apply(y)
y = phase.ffn_2.apply(y)
return y
@@ -5,41 +5,33 @@ from lightx2v.utils.envs import *
class WanVaceTransformerInfer(WanOffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.vace_block_nums = len(self.config.vace_layers)
self.vace_blocks_num = len(self.config.vace_layers)
self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}
def infer(self, weights, pre_infer_out):
pre_infer_out.adapter_output["hints"] = self.infer_vace(weights, pre_infer_out)
x = self.infer_main_blocks(weights, pre_infer_out)
pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
self.infer_vace_blocks(weights.vace_blocks, pre_infer_out)
x = self.infer_main_blocks(weights.blocks, pre_infer_out)
return self.infer_non_blocks(weights, x, pre_infer_out.embed)
def infer_vace(self, weights, pre_infer_out):
c = weights.vace_patch_embedding.apply(pre_infer_out.vace_context.unsqueeze(0).to(self.sensitive_layer_dtype))
def vace_pre_process(self, patch_embedding, vace_context):
c = patch_embedding.apply(vace_context.unsqueeze(0).to(self.sensitive_layer_dtype))
c = c.flatten(2).transpose(1, 2).contiguous().squeeze(0)
return c
def infer_vace_blocks(self, vace_blocks, pre_infer_out):
pre_infer_out.adapter_output["hints"] = []
self.infer_state = "vace"
hints = []
for i in range(self.vace_block_nums):
c, c_skip = self.infer_vace_block(weights.vace_blocks[i], c, pre_infer_out.x, pre_infer_out)
hints.append(c_skip)
if hasattr(self, "weights_stream_mgr"):
self.weights_stream_mgr.init(self.vace_blocks_num, self.phases_num, self.offload_ratio)
self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
self.infer_state = "base"
return hints
def infer_vace_block(self, weights, c, x, pre_infer_out):
if hasattr(weights, "before_proj"):
c = weights.before_proj.apply(c) + x
c = self.infer_block(weights, c, pre_infer_out)
c_skip = weights.after_proj.apply(c)
return c, c_skip
if hasattr(self, "weights_stream_mgr"):
self.weights_stream_mgr.init(self.blocks_num, self.phases_num, self.offload_ratio)
def post_process(self, x, y, c_gate_msa, pre_infer_out):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
hint_idx = self.vace_blocks_mapping[self.block_idx]
x = x + pre_infer_out.adapter_output["hints"][hint_idx] * pre_infer_out.adapter_output.get("context_scale", 1.0)
return x
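Instead of running the VACE branch inline, the base path now looks up self.block_idx in vace_blocks_mapping during post_process and, when present, adds the matching precomputed hint scaled by context_scale. A small worked example of that mapping, assuming the vace_layers list from the comment later in this diff:

# Worked example; vace_layers = [0, 2, 4, ..., 28] as in the comment further down.
vace_layers = list(range(0, 30, 2))
vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(vace_layers)}
print(vace_blocks_mapping[4])        # 2 -> base block 4 adds hints[2]
print(5 in vace_blocks_mapping)      # False -> block 5 gets no hint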
@@ -27,7 +27,7 @@ class WanTransformerWeights(WeightModule):
self.add_module("blocks", self.blocks)
# post blocks weights
# non blocks weights
self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
@@ -67,15 +67,6 @@ class WanTransformerAttentionBlock(WeightModule):
self.compute_phases = WeightModuleList(
[
WanModulation(
block_index,
block_prefix,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanSelfAttention(
block_index,
block_prefix,
@@ -109,7 +100,7 @@ class WanTransformerAttentionBlock(WeightModule):
self.add_module("compute_phases", self.compute_phases)
class WanModulation(WeightModule):
class WanSelfAttention(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
@@ -131,20 +122,6 @@ class WanModulation(WeightModule):
),
)
class WanSelfAttention(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.quant_method = config.get("quant_method", None)
self.sparge = config.get("sparge", False)
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.add_module(
"norm1",
LN_WEIGHT_REGISTER["Default"](),
@@ -9,8 +9,6 @@ from lightx2v.utils.registry_factory import (
)
# "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
# {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
class WanVaceTransformerWeights(WanTransformerWeights):
def __init__(self, config):
super().__init__(config)
@@ -44,7 +42,7 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
def __init__(self, base_block_idx, block_index, task, mm_type, config, block_prefix):
super().__init__(block_index, task, mm_type, config, block_prefix)
if base_block_idx == 0:
self.add_module(
self.compute_phases[0].add_module(
"before_proj",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.before_proj.weight",
@@ -53,7 +51,8 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
self.lazy_load_file,
),
)
self.add_module(
self.compute_phases[-1].add_module(
"after_proj",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.after_proj.weight",