Commit 53eae786 authored by gushiqiao, committed by GitHub

Merge pull request #92 from ModelTC/dev_gradio

Update gradio
parents 7a8951ba 5fc97e4f
#!/bin/bash
# Paths to the local lightx2v checkout and the Wan model weights.
lightx2v_path=/path/to/lightx2v
model_path=/path/to/wan

export CUDA_VISIBLE_DEVICES=0                            # run on a single GPU
export CUDA_LAUNCH_BLOCKING=1                            # synchronous kernel launches for easier debugging
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true                       # lightx2v profiling/debug logging
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True  # reduce CUDA memory fragmentation

# English-language demo.
python gradio_demo.py \
    --model_path $model_path \
    --server_name 0.0.0.0 \
    --server_port 8005

# Chinese-language demo (same arguments).
# python gradio_demo_zh.py \
#     --model_path $model_path \
#     --server_name 0.0.0.0 \
#     --server_port 8005
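Once the script is running, the demo should be reachable in a browser at http://<server-ip>:8005, matching the --server_name and --server_port flags above.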
......@@ -56,8 +56,9 @@ class WeightAsyncStreamManager(object):
class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2):
def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2, offload_gra="phase"):
super().__init__(blocks_num, offload_ratio, phases_num)
self.offload_gra = offload_gra
self.worker_stop_event = threading.Event()
self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
self.disk_task_queue = queue.PriorityQueue()
......@@ -72,7 +73,10 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
def _start_disk_workers(self, num_workers):
for i in range(num_workers):
if self.offload_gra == "phase":
worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
else:
worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
worker.start()
self.disk_workers.append(worker)
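For reference, a minimal construction sketch (all values are illustrative, not taken from this diff): offload_gra decides which worker loop _start_disk_workers spawns, so a single manager only ever runs one of the two loops.

from lightx2v.common.offload.manager import LazyWeightAsyncStreamManager

# Illustrative values; blocks_num and max_memory depend on the model and host RAM.
mgr = LazyWeightAsyncStreamManager(
    blocks_num=40,
    offload_ratio=1,
    phases_num=1,
    num_disk_workers=2,
    max_memory=2,          # GB of pinned host memory for the staging buffer
    offload_gra="block",   # "phase" (the default) keeps per-phase disk workers
)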
......@@ -96,11 +100,34 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _disk_worker_loop_block(self):
while not self.worker_stop_event.is_set():
try:
_, task = self.disk_task_queue.get(timeout=0.5)
if task is None:
break
block_idx, block = task
for phase in block.compute_phases:
phase.load_from_disk()
self.pin_memory_buffer.push(block_idx, block)
with self.task_lock:
if block_idx in self.pending_tasks:
del self.pending_tasks[block_idx]
except queue.Empty:
continue
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _async_prefetch_block(self, weights):
next_block_idx = self.pin_memory_buffer.get_max_block_index()
if next_block_idx < 0:
next_block_idx = 0
if self.offload_gra == "phase":
for phase_idx in range(self.phases_num):
obj_key = (next_block_idx, phase_idx)
......@@ -114,15 +141,33 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
priority_key = (next_block_idx, phase_idx)
self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
else:
obj_key = next_block_idx
if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
return
with self.task_lock:
self.pending_tasks[obj_key] = True
block = weights.blocks[next_block_idx]
self.disk_task_queue.put((obj_key, (next_block_idx, block)))
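A small sketch of the two queue-entry shapes produced above (the stand-in objects are hypothetical); because offload_gra is fixed at construction, a given queue only ever holds one shape, so priority keys stay mutually comparable.

import queue

phase_queue = queue.PriorityQueue()              # offload_gra == "phase"
block_idx, phase_idx = 5, 1
phase = object()                                 # stands in for a compute phase
phase_queue.put(((block_idx, phase_idx), (block_idx, phase_idx, phase)))

block_queue = queue.PriorityQueue()              # offload_gra == "block"
block = object()                                 # stands in for a block with .compute_phases
block_queue.put((block_idx, (block_idx, block)))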
def _sync_prefetch_block(self, weights):
block_idx = 0
while not self.pin_memory_buffer.is_nearly_full():
if self.offload_gra == "phase":
for phase_idx in range(self.phases_num):
phase = weights.blocks[block_idx].compute_phases[phase_idx]
logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
else:
block = weights.blocks[block_idx]
logger.info(f"Synchronous loading: block={block_idx}")
for phase in block.compute_phases:
phase.load_from_disk()
self.pin_memory_buffer.push(block_idx, block)
block_idx += 1
def prefetch_weights_from_disk(self, weights):
......@@ -132,6 +177,37 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
self._sync_prefetch_block(weights)
self.initial_prefetch_done = True
def prefetch_weights(self, block_idx, blocks):
obj_key = block_idx
if not self.pin_memory_buffer.exists(obj_key):
is_loading = False
with self.task_lock:
if obj_key in self.pending_tasks:
is_loading = True
if is_loading:
start_time = time.time()
while not self.pin_memory_buffer.exists(obj_key):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}")
else:
logger.info("Not find prefetch block={block_idx} task. This is a bug.")
with torch.cuda.stream(self.cuda_load_stream):
block = self.pin_memory_buffer.get(obj_key)
block.to_cuda_async()
self.active_weights[2] = (obj_key, block)
with torch.cuda.stream(self.cpu_load_stream):
if block_idx < self.offload_block_num:
if self.active_weights[1] is not None:
old_key, old_block = self.active_weights[1]
if self.pin_memory_buffer.exists(old_key):
old_block.to_cpu_async()
self.pin_memory_buffer.pop(old_key)
def prefetch_phase(self, block_idx, phase_idx, blocks):
obj_key = (block_idx, phase_idx)
......@@ -193,32 +269,45 @@ class MemoryBuffer:
self.cache = OrderedDict()
self.max_mem = max_memory_bytes
self.used_mem = 0
self.phases_size_map = {}
self.obj_size_map = {}
self.lock = threading.Lock()
self.insertion_order = []
self.insertion_index = 0
def push(self, key, phase_obj):
def push(self, key, obj):
with self.lock:
if key in self.cache:
return
_, phase_idx = key
if phase_idx not in self.phases_size_map:
self.phases_size_map[phase_idx] = phase_obj.calculate_size()
size = self.phases_size_map[phase_idx]
if hasattr(obj, "compute_phases"):
obj_idx = key
if len(self.obj_size_map) == 0:
_size = 0
for phase in obj.compute_phases:
_size += phase.calculate_size()
self.obj_size_map[0] = _size
size = self.obj_size_map[0]
else:
_, obj_idx = key
if obj_idx not in self.obj_size_map:
self.obj_size_map[obj_idx] = obj.calculate_size()
size = self.obj_size_map[obj_idx]
self.cache[key] = (size, phase_obj, self.insertion_index)
self.cache[key] = (size, obj, self.insertion_index)
self.insertion_order.append((key, self.insertion_index))
self.insertion_index += 1
self.used_mem += size
def _remove_key(self, key):
if key in self.cache:
size, phase, idx = self.cache.pop(key)
size, obj, idx = self.cache.pop(key)
try:
if hasattr(obj, "compute_phases"):
for phase in obj.compute_phases:
phase.clear()
else:
obj.clear()
except Exception as e:
logger.info(f"Error clearing phase: {e}")
logger.info(f"Error clearing obj: {e}")
self.used_mem -= size
self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]
......@@ -226,14 +315,22 @@ class MemoryBuffer:
def get(self, key, default=None):
with self.lock:
if key in self.cache:
size, phase, idx = self.cache[key]
return phase
size, obj, idx = self.cache[key]
return obj
return default
def exists(self, key):
with self.lock:
return key in self.cache
def pop_front(self):
with self.lock:
if not self.insertion_order:
return False
front_key, _ = self.insertion_order[0]
self._remove_key(front_key)
return True
def pop(self, key):
with self.lock:
if key in self.cache:
......@@ -249,7 +346,10 @@ class MemoryBuffer:
with self.lock:
if not self.cache:
return -1
if isinstance(list(self.cache.keys())[-1], tuple):
return (list(self.cache.keys())[-1][0] + 1) % 40
else:
return (list(self.cache.keys())[-1] + 1) % 40
def clear(self):
with self.lock:
......
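To illustrate the keying convention the MemoryBuffer changes above introduce, a hedged sketch with hypothetical stand-in objects (a real buffer holds either phases or blocks, never a mix; MemoryBuffer is assumed importable from the same module as the managers):

from lightx2v.common.offload.manager import MemoryBuffer

class _Phase:                                    # hypothetical stand-in
    def calculate_size(self):
        return 64 * 1024 * 1024                  # pretend a phase occupies 64 MB
    def clear(self):
        pass

class _Block:                                    # hypothetical stand-in
    def __init__(self):
        self.compute_phases = [_Phase(), _Phase()]

# Phase granularity: keys are (block_idx, phase_idx) tuples.
phase_buf = MemoryBuffer(2 * (1024 ** 3))
phase_buf.push((3, 1), _Phase())
print(phase_buf.get_max_block_index())           # (3 + 1) % 40 -> 4

# Block granularity: keys are bare block indices; objects expose .compute_phases.
block_buf = MemoryBuffer(2 * (1024 ** 3))
block_buf.push(7, _Block())
print(block_buf.get_max_block_index())           # (7 + 1) % 40 -> 8
block_buf.pop_front()                            # evicts block 7 and clears its phases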
import torch
from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
from lightx2v.common.offload.manager import (
WeightAsyncStreamManager,
LazyWeightAsyncStreamManager,
)
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from loguru import logger
import os
from functools import partial
class WanTransformerInfer(BaseTransformerInfer):
......@@ -21,10 +20,12 @@ class WanTransformerInfer(BaseTransformerInfer):
self.head_dim = config["dim"] // config["num_heads"]
self.window_size = config.get("window_size", (-1, -1))
self.parallel_attention = None
self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
if config.get("rotary_chunk", False):
chunk_size = config.get("rotary_chunk_size", 100)
self.apply_rotary_emb_func = partial(apply_rotary_emb_chunk, chunk_size=chunk_size)
else:
self.apply_rotary_emb_func = apply_rotary_emb
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.mask_map = None
if self.config["cpu_offload"]:
if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"]
......@@ -32,7 +33,10 @@ class WanTransformerInfer(BaseTransformerInfer):
offload_ratio = 1
offload_granularity = self.config.get("offload_granularity", "block")
if offload_granularity == "block":
if not self.config.get("lazy_load", False):
self.infer_func = self._infer_with_offload
else:
self.infer_func = self._infer_with_lazy_offload
elif offload_granularity == "phase":
if not self.config.get("lazy_load", False):
self.infer_func = self._infer_with_phases_offload
......@@ -52,6 +56,7 @@ class WanTransformerInfer(BaseTransformerInfer):
phases_num=self.phases_num,
num_disk_workers=self.config.get("num_disk_workers", 2),
max_memory=self.config.get("max_memory", 2),
offload_gra=offload_granularity,
)
else:
self.infer_func = self._infer_without_offload
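For orientation, an illustrative config fragment (keys taken from the lookups above; values are assumptions) that would route inference through the new _infer_with_lazy_offload path:

config = {
    "cpu_offload": True,
    "offload_granularity": "block",   # "phase" selects the per-phase variants
    "lazy_load": True,                # False keeps the original non-lazy offload paths
    "offload_ratio": 1,
    "num_disk_workers": 2,
    "max_memory": 2,                  # GB of pinned host memory
    "clean_cuda_cache": False,
}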
......@@ -68,10 +73,10 @@ class WanTransformerInfer(BaseTransformerInfer):
return cu_seqlens_q, cu_seqlens_k
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num):
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
......@@ -96,7 +101,44 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
for block_idx in range(self.blocks_num):
if block_idx == 0:
block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
block.to_cuda()
self.weights_stream_mgr.active_weights[0] = (block_idx, block)
if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_block(
self.weights_stream_mgr.active_weights[0][1],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
self.weights_stream_mgr.swap_weights()
if block_idx == self.blocks_num - 1:
self.weights_stream_mgr.pin_memory_buffer.pop_front()
self.weights_stream_mgr._async_prefetch_block(weights)
if self.clean_cuda_cache:
del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache()
return x
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(weights.blocks_num):
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
......@@ -133,11 +175,18 @@ class WanTransformerInfer(BaseTransformerInfer):
self.weights_stream_mgr.swap_phases()
if self.clean_cuda_cache:
del attn_out, y_out, y
torch.cuda.empty_cache()
if self.clean_cuda_cache:
del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache()
return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
for block_idx in range(weights.blocks_num):
......@@ -198,22 +247,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None):
if rotary_emb is None:
return None
self.use_real = False
rope_t_dim = 44
if self.use_real:
freqs_cos, freqs_sin = rotary_emb
freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
return freqs_cos, freqs_sin
else:
freqs_cis = rotary_emb
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
return freqs_cis
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num):
x = self.infer_block(
weights.blocks[block_idx],
......@@ -225,12 +259,6 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs,
context,
)
if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
for ipa_out in audio_dit_blocks:
if block_idx in ipa_out:
cur_modify = ipa_out[block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
......@@ -290,23 +318,14 @@ class WanTransformerInfer(BaseTransformerInfer):
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention:
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
else:
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
q = self.apply_rotary_emb_func(q, freqs_i)
k = self.apply_rotary_emb_func(k, freqs_i)
k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
if self.clean_cuda_cache:
del freqs_i, norm1_out, norm1_weight, norm1_bias
......@@ -322,7 +341,6 @@ class WanTransformerInfer(BaseTransformerInfer):
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
mask_map=self.mask_map,
)
else:
attn_out = self.parallel_attention(
......@@ -388,6 +406,7 @@ class WanTransformerInfer(BaseTransformerInfer):
q,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
)
img_attn_out = weights.cross_attn_2.apply(
q=q,
k=k_img,
......
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
......@@ -20,45 +19,6 @@ def compute_freqs(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1 ##for r2v add 1 channel
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
return freqs_i
def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
......@@ -115,7 +75,7 @@ def apply_rotary_emb(x, freqs_i):
return x_i.to(torch.bfloat16)
def apply_rotary_emb_chunk(x, freqs_i, chunk_size=100, remaining_chunk_size=100):
def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
n = x.size(1)
seq_len = freqs_i.size(0)
......
import os
import torch
from functools import lru_cache
......