Unverified commit 74eeb429 authored by Gu Shiqiao, committed by GitHub

reconstruct disk offload and fix lightx2v_platform bugs (#558)


Co-authored-by: helloyongyang <yongyang1030@163.com>
parent f7cdbcb5
......@@ -23,35 +23,6 @@ class WeightModule:
if hasattr(parameter, "load"):
parameter.load(weight_dict)
def calculate_size(self):
total_size = 0
for _, module in self._modules.items():
if hasattr(module, "_calculate_size"):
total_size += module._calculate_size()
for _, parameter in self._parameters.items():
if hasattr(parameter, "_calculate_size"):
total_size += parameter._calculate_size()
return total_size
def load_from_disk(self):
for _, module in self._modules.items():
if hasattr(module, "load_from_disk"):
module.load_from_disk()
for _, parameter in self._parameters.items():
if hasattr(parameter, "load_from_disk"):
parameter.load_from_disk()
def clear(self):
for _, module in self._modules.items():
if hasattr(module, "clear"):
module.clear()
for _, parameter in self._parameters.items():
if hasattr(parameter, "clear"):
parameter.clear()
def state_dict(self, destination=None):
if destination is None:
destination = {}
......@@ -74,6 +45,14 @@ class WeightModule:
module.load_state_dict(destination, block_index, adapter_block_index)
return destination
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
for _, param in self._parameters.items():
if param is not None:
param.load_state_dict_from_disk(block_index, adapter_block_index)
for _, module in self._modules.items():
if module is not None:
module.load_state_dict_from_disk(block_index, adapter_block_index)
def named_parameters(self, prefix=""):
for name, param in self._parameters.items():
if param is not None:
......
import gc
import queue
import threading
import time
from collections import OrderedDict
import torch
from loguru import logger
from packaging.version import parse
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
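# A minimal sketch (assuming AI_DEVICE == "cuda") of the device-agnostic stream pattern
# used throughout this module: getattr(torch, AI_DEVICE) resolves to torch.cuda, and other
# accelerator backends expose the same Stream()/stream() API under torch.<backend>.
# The function name is illustrative, not part of the module.
def _stream_pattern_sketch():
    side_stream = torch_device_module.Stream(priority=0)
    with torch_device_module.stream(side_stream):
        buf = torch.empty(1024, device=AI_DEVICE)  # allocation/copy queued on the side stream
    side_stream.synchronize()  # wait before using buf on the default stream
    return buf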
class WeightAsyncStreamManager(object):
def __init__(self, offload_granularity):
self.offload_granularity = offload_granularity
self.init_stream = torch.cuda.Stream(priority=0)
self.init_stream = torch_device_module.Stream(priority=0)
self.need_init_first_buffer = True
torch_version = parse(torch.__version__.split("+")[0])
if torch_version >= parse("2.7"):
self.cuda_load_stream = torch.cuda.Stream(priority=1)
self.compute_stream = torch.cuda.Stream(priority=1)
if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
self.cuda_load_stream = torch_device_module.Stream(priority=1)
self.compute_stream = torch_device_module.Stream(priority=1)
else:
self.cuda_load_stream = torch.cuda.Stream(priority=0)
self.compute_stream = torch.cuda.Stream(priority=-1)
self.cuda_load_stream = torch_device_module.Stream(priority=0)
self.compute_stream = torch_device_module.Stream(priority=-1)
def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
self.need_init_first_buffer = True
if self.offload_granularity == "block":
assert blocks_cpu_buffer is not None
self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))]
elif self.offload_granularity == "phase":
assert phases_cpu_buffer is not None
self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))]
else:
raise NotImplementedError
def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
self.need_init_first_buffer = True
if self.offload_granularity == "block":
assert blocks_cuda_buffer is not None
self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
......@@ -32,17 +42,32 @@ class WeightAsyncStreamManager(object):
raise NotImplementedError
def init_first_buffer(self, blocks, adapter_block_idx=None):
if self.offload_granularity == "block":
with torch.cuda.stream(self.init_stream):
self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
else:
with torch.cuda.stream(self.init_stream):
self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
with torch_device_module.stream(self.init_stream):
if hasattr(self, "cpu_buffers"):
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
else:
if self.offload_granularity == "block":
self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
else:
self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
self.init_stream.synchronize()
self.need_init_first_buffer = False
def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
with torch.cuda.stream(self.cuda_load_stream):
self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
with torch_device_module.stream(self.cuda_load_stream):
if hasattr(self, "cpu_buffers"):
self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
else:
self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.cuda_load_stream):
if hasattr(self, "cpu_buffers"):
self.cpu_buffers[phase_idx].load_state_dict_from_disk(block_idx, adapter_block_idx)
self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[phase_idx].state_dict(), block_idx, adapter_block_idx)
else:
self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
def swap_blocks(self):
self.cuda_load_stream.synchronize()
......@@ -52,347 +77,6 @@ class WeightAsyncStreamManager(object):
self.cuda_buffers[0],
)
def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
with torch.cuda.stream(self.cuda_load_stream):
self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
def swap_phases(self):
self.cuda_load_stream.synchronize()
self.compute_stream.synchronize()
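# Hedged driver sketch (the function and its `compute` callback are assumptions, not part of
# this module) showing how prefetch_weights()/swap_blocks() double-buffer block weights:
# the next block is staged on cuda_load_stream while the current one runs on compute_stream.
def _offload_loop_sketch(manager, blocks, compute):
    manager.init_first_buffer(blocks)  # stage block 0 into cuda_buffers[0]
    for block_idx in range(len(blocks)):
        if block_idx < len(blocks) - 1:
            manager.prefetch_weights(block_idx + 1, blocks)  # async load of the next block
        with torch_device_module.stream(manager.compute_stream):
            compute(manager.cuda_buffers[0])  # run on the block currently resident on device
        manager.swap_blocks()  # synchronize both streams and rotate the prefetched buffer in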
class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
def __init__(
self,
blocks_num,
offload_ratio=1,
phases_num=1,
num_disk_workers=1,
max_memory=2,
offload_gra="phase",
):
super().__init__(blocks_num, offload_ratio, phases_num)
self.offload_gra = offload_gra
self.worker_stop_event = threading.Event()
self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
self.disk_task_queue = queue.PriorityQueue()
self.disk_workers = []
self.release_workers = []
self._start_disk_workers(num_disk_workers)
self.initial_prefetch_done = False
self.pending_tasks = {}
self.task_lock = threading.Lock()
self.last_used_time = {}
self.time_lock = threading.Lock()
def _start_disk_workers(self, num_workers):
for i in range(num_workers):
if self.offload_gra == "phase":
worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
else:
worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
worker.start()
self.disk_workers.append(worker)
def _disk_worker_loop(self):
while not self.worker_stop_event.is_set():
try:
_, task = self.disk_task_queue.get(timeout=0.5)
if task is None:
break
block_idx, phase_idx, phase = task
phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
with self.task_lock:
if (block_idx, phase_idx) in self.pending_tasks:
del self.pending_tasks[(block_idx, phase_idx)]
except queue.Empty:
continue
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _disk_worker_loop_block(self):
while not self.worker_stop_event.is_set():
try:
_, task = self.disk_task_queue.get(timeout=0.5)
if task is None:
break
block_idx, block = task
for phase in block.compute_phases:
phase.load_from_disk()
self.pin_memory_buffer.push(block_idx, block)
with self.task_lock:
if block_idx in self.pending_tasks:
del self.pending_tasks[block_idx]
except queue.Empty:
continue
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _async_prefetch_block(self, blocks, next_block_idx=None):
if next_block_idx is None:
next_block_idx = self.pin_memory_buffer.get_max_block_index()
if next_block_idx < 0:
next_block_idx = 0
if next_block_idx == self.blocks_num:
return
if self.offload_gra == "phase":
for phase_idx in range(self.phases_num):
obj_key = (next_block_idx, phase_idx)
if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
continue
with self.task_lock:
self.pending_tasks[obj_key] = True
phase = blocks[next_block_idx].compute_phases[phase_idx]
priority_key = (next_block_idx, phase_idx)
self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
else:
obj_key = next_block_idx
if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
return
with self.task_lock:
self.pending_tasks[obj_key] = True
block = blocks[next_block_idx]
self.disk_task_queue.put((obj_key, (next_block_idx, block)))
def _sync_prefetch_block(self, blocks):
block_idx = 0
while not self.pin_memory_buffer.is_nearly_full():
if self.offload_gra == "phase":
for phase_idx in range(self.phases_num):
phase = blocks[block_idx].compute_phases[phase_idx]
logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
else:
block = blocks[block_idx]
logger.info(f"Synchronous loading: block={block_idx}")
for phase in block.compute_phases:
phase.load_from_disk()
self.pin_memory_buffer.push(block_idx, block)
block_idx += 1
if block_idx == self.blocks_num:
break
def prefetch_weights_from_disk(self, blocks):
if self.initial_prefetch_done:
return
self._sync_prefetch_block(blocks)
self.initial_prefetch_done = True
def prefetch_weights(self, block_idx, blocks):
obj_key = block_idx
if not self.pin_memory_buffer.exists(obj_key):
is_loading = False
with self.task_lock:
if obj_key in self.pending_tasks:
is_loading = True
if is_loading:
start_time = time.time()
while not self.pin_memory_buffer.exists(obj_key):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}")
else:
logger.info("Not find prefetch block={block_idx} task.")
logger.info("Sync prefetch block={block_idx}.")
self._async_prefetch_block(blocks, block_idx)
start_time = time.time()
for phase_idx in range(self.phases_num):
while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
time.sleep(0.001)
if time.time() - start_time > 15:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
with torch.cuda.stream(self.cuda_load_stream):
block = self.pin_memory_buffer.get(obj_key)
block.to_cuda_async()
self.cuda_buffers[2] = (obj_key, block)
with torch.cuda.stream(self.cpu_load_stream):
if block_idx < self.offload_blocks_num:
if self.cuda_buffers[1] is not None:
old_key, old_block = self.cuda_buffers[1]
if self.pin_memory_buffer.exists(old_key):
old_block.to_cpu_async()
self.pin_memory_buffer.pop(old_key)
def prefetch_phase(self, block_idx, phase_idx, blocks):
obj_key = (block_idx, phase_idx)
if not self.pin_memory_buffer.exists(obj_key):
is_loading = False
with self.task_lock:
if obj_key in self.pending_tasks:
is_loading = True
if is_loading:
start_time = time.time()
while not self.pin_memory_buffer.exists(obj_key):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
else:
logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
self._async_prefetch_block(blocks, block_idx)
start_time = time.time()
while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
with torch.cuda.stream(self.cuda_load_stream):
phase = self.pin_memory_buffer.get(obj_key)
phase.to_cuda_async()
self.cuda_buffers[2] = (obj_key, phase)
with torch.cuda.stream(self.cpu_load_stream):
if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
if self.cuda_buffers[1] is not None:
old_key, old_phase = self.cuda_buffers[1]
if self.pin_memory_buffer.exists(old_key):
old_phase.to_cpu_async()
self.pin_memory_buffer.pop(old_key)
def shutdown(self):
self.worker_stop_event.set()
while not self.disk_task_queue.empty():
try:
self.disk_task_queue.get_nowait()
except queue.Empty:
continue
for _ in self.disk_workers:
self.disk_task_queue.put((0, None))
for worker in self.disk_workers:
worker.join(timeout=5)
for worker in self.release_workers:
worker.join(timeout=5)
logger.info("All worker threads have been closed")
def clear(self):
self.pin_memory_buffer.clear()
self.shutdown()
class MemoryBuffer:
def __init__(self, max_memory_bytes=8 * (1024**3)):
self.cache = OrderedDict()
self.max_mem = max_memory_bytes
self.used_mem = 0
self.obj_size_map = {}
self.lock = threading.Lock()
self.insertion_order = []
self.insertion_index = 0
def push(self, key, obj):
with self.lock:
if key in self.cache:
return
if hasattr(obj, "compute_phases"):
obj_idx = key
if len(self.obj_size_map) == 0:
_size = 0
for phase in obj.compute_phases:
_size += phase.calculate_size()
self.obj_size_map[0] = _size
size = self.obj_size_map[0]
else:
_, obj_idx = key
if obj_idx not in self.obj_size_map:
self.obj_size_map[obj_idx] = obj.calculate_size()
size = self.obj_size_map[obj_idx]
self.cache[key] = (size, obj, self.insertion_index)
self.insertion_order.append((key, self.insertion_index))
self.insertion_index += 1
self.used_mem += size
def _remove_key(self, key):
if key in self.cache:
size, obj, idx = self.cache.pop(key)
try:
if hasattr(obj, "compute_phases"):
for phase in obj.compute_phases:
phase.clear()
else:
obj.clear()
except Exception as e:
logger.info(f"Error clearing obj: {e}")
self.used_mem -= size
self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]
def get(self, key, default=None):
with self.lock:
if key in self.cache:
size, obj, idx = self.cache[key]
return obj
return default
def exists(self, key):
with self.lock:
return key in self.cache
def pop_front(self):
with self.lock:
if not self.insertion_order:
return False
front_key, _ = self.insertion_order[0]
self._remove_key(front_key)
return True
def pop(self, key):
with self.lock:
if key in self.cache:
self._remove_key(key)
return True
return False
def is_nearly_full(self):
with self.lock:
return self.used_mem >= self.max_mem * 0.9
def get_max_block_index(self):
with self.lock:
if not self.cache:
return -1
if isinstance(list(self.cache.keys())[-1], tuple):
return (list(self.cache.keys())[-1][0] + 1) % 40
else:
return (list(self.cache.keys())[-1] + 1) % 40
def clear(self):
with self.lock:
for key in list(self.cache.keys()):
self._remove_key(key)
self.insertion_order = []
self.insertion_index = 0
self.used_mem = 0
torch.cuda.empty_cache()
gc.collect()
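# Illustrative-only usage of the pinned-memory cache above; `phase` is assumed to be any
# weight phase exposing load_from_disk()/calculate_size()/clear().
def _memory_buffer_sketch(phase):
    buf = MemoryBuffer(max_memory_bytes=2 * (1024**3))  # 2 GiB host budget
    phase.load_from_disk()  # populate pinned host tensors
    buf.push((0, 0), phase)  # cache as block 0 / phase 0
    if buf.is_nearly_full():  # >= 90% of the budget in use
        buf.pop_front()  # evict the oldest entry
    return buf.get((0, 0))  # returns the cached phase, or None if it was evicted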
......@@ -73,9 +73,10 @@ class FlashAttn3Weight(AttnWeightTemplate):
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
if model_cls is not None and model_cls in ["hunyuan_video_1.5"]:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func_v3(
q,
k,
......
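# Shape sketch for the conditional flattening above (tensor sizes are illustrative):
# the varlen kernel consumes packed 3D (total_tokens, heads, head_dim) inputs, so a 4D
# (bs, seq_len, heads, head_dim) tensor is collapsed first for hunyuan_video_1.5.
def _flatten_qkv_sketch():
    q4 = torch.randn(2, 128, 24, 64)  # (bs, seq_len, num_heads, head_dim)
    return q4.reshape(-1, q4.shape[-2], q4.shape[-1])  # (bs * seq_len, num_heads, head_dim) == (256, 24, 64)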
......@@ -30,3 +30,6 @@ class AttnWeightTemplate(metaclass=ABCMeta):
def load_state_dict(self, destination, block_index, adapter_block_index=None):
return {}
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
pass
......@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod
import torch
from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class Conv2dWeightTemplate(metaclass=ABCMeta):
......@@ -34,8 +35,8 @@ class Conv2dWeight(Conv2dWeightTemplate):
super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].cuda()
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
self.weight = weight_dict[self.weight_name].to(AI_DEVICE)
self.bias = weight_dict[self.bias_name].to(AI_DEVICE) if self.bias_name is not None else None
def apply(self, input_tensor):
input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
......@@ -47,9 +48,9 @@ class Conv2dWeight(Conv2dWeightTemplate):
self.bias = self.bias.cpu(non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda(non_blocking=non_blocking)
self.weight = self.weight.to(AI_DEVICE, non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.cuda(non_blocking=non_blocking)
self.bias = self.bias.to(AI_DEVICE, non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
......@@ -58,10 +59,3 @@ class Conv2dWeight(Conv2dWeightTemplate):
if self.bias is not None:
destination[self.bias_name] = self.bias.cpu().detach().clone()
return destination
def clear(self):
attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
......@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod
import torch
from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class Conv3dWeightTemplate(metaclass=ABCMeta):
......@@ -70,9 +71,9 @@ class Conv3dWeight(Conv3dWeightTemplate):
return input_tensor
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
if hasattr(self, "pin_bias") and self.pin_bias is not None:
self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
......@@ -91,10 +92,3 @@ class Conv3dWeight(Conv3dWeightTemplate):
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias # .cpu().detach().clone()
return destination
def clear(self):
attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
......@@ -5,12 +5,14 @@ import torch
import torch.nn.functional as F
from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class EmbeddingWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
self.weight_name = weight_name
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
......@@ -19,7 +21,7 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
def load(self, weight_dict):
if not self.lazy_load:
if self.create_cuda_buffer:
self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
self.weight_cuda_buffer = weight_dict[self.weight_name].to(AI_DEVICE)
else:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
......@@ -32,7 +34,7 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
self.weight = weight_dict[self.weight_name]
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
......
This diff is collapsed.
......@@ -5,16 +5,18 @@ import torch
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
from .triton_ops import norm_infer
class LNWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
self.weight_name = weight_name
self.bias_name = bias_name
self.eps = eps
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
......@@ -23,53 +25,71 @@ class LNWeightTemplate(metaclass=ABCMeta):
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def load(self, weight_dict):
if not self.lazy_load:
if self.create_cuda_buffer:
if self.weight_name is not None:
self.weight_cuda_buffer = weight_dict[self.weight_name].cuda().t()
if self.bias_name is not None:
self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
if self.create_cuda_buffer:
self._load_cuda_buffers(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffers()
else:
self._load_default_tensors(weight_dict)
def _load_default_tensors(self, weight_dict):
if not self.lazy_load and self.weight_name is not None:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_tensor = weight_dict[self.weight_name]
self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
bias_tensor = weight_dict[self.bias_name] if self.bias_name is not None else None
self.pin_bias = self._create_cpu_pin_tensor(bias_tensor) if bias_tensor is not None else None
self.bias = None
del weight_dict[self.weight_name]
else:
if self.weight_name is not None:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
if self.bias_name is not None:
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
self.pin_bias = None
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
else:
self.weight = None
self.bias = None
def _calculate_size(self):
if self.weight is None:
return 0
if self.bias is not None:
return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
return self.weight.numel() * self.weight.element_size()
def clear(self):
attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
self.weight = weight_dict[self.weight_name]
self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
else:
self.weight = None
self.bias = None
def _get_tensor(self, name, weight_dict=None, use_infer_dtype=False):
if name is None:
return None
if self.lazy_load:
tensor = self.lazy_load_file.get_tensor(name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[name]
return tensor
def _create_cpu_pin_tensor(self, tensor):
if tensor is None:
return None
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_cuda_buffers(self, weight_dict):
weight_tensor = self._get_tensor(self.weight_name, weight_dict, use_infer_dtype=self.lazy_load)
if weight_tensor is not None:
self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
bias_tensor = self._get_tensor(self.bias_name, weight_dict, use_infer_dtype=self.lazy_load)
if bias_tensor is not None:
self.bias_cuda_buffer = bias_tensor.to(AI_DEVICE)
def _load_cpu_pin_buffers(self):
weight_tensor = self._get_tensor(self.weight_name, use_infer_dtype=True)
if weight_tensor is not None:
self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
else:
self.weight = None
bias_tensor = self._get_tensor(self.bias_name, use_infer_dtype=True)
if bias_tensor is not None:
self.pin_bias = self._create_cpu_pin_tensor(bias_tensor)
else:
self.bias = None
self.pin_bias = None
@abstractmethod
def apply(self, input_tensor):
......@@ -81,11 +101,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
def to_cuda(self, non_blocking=False):
if hasattr(self, "pin_weight") and self.pin_weight is not None:
self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
else:
self.weight = None
if hasattr(self, "pin_bias") and self.pin_bias is not None:
self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
else:
self.bias = None
......@@ -129,28 +149,33 @@ class LNWeightTemplate(metaclass=ABCMeta):
else:
self.bias = None
@LN_WEIGHT_REGISTER("Default")
class LNWeight(LNWeightTemplate):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def load_from_disk(self):
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.weight_name is not None:
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
else:
self.weight = None
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
if not torch._dynamo.is_compiling():
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
if self.is_post_adapter:
assert adapter_block_index is not None
self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
else:
self.bias = None
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
bias_tensor = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
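# Worked example of the block-index remapping used by load_state_dict_from_disk() above
# (the weight name is a hypothetical example): re.sub with count=1 rewrites only the first
# ".<digits>" segment, e.g. "double_blocks.0.img_attn_q_norm.weight" -> ".17." for block 17.
def _rename_sketch(block_index):
    name = "double_blocks.0.img_attn_q_norm.weight"
    return re.sub(r"\.\d+", lambda m: f".{block_index}", name, count=1)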
@LN_WEIGHT_REGISTER("Default")
class LNWeight(LNWeightTemplate):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def apply(self, input_tensor):
if self.sensitive_layer_dtype != self.infer_dtype:
......@@ -169,25 +194,8 @@ class LNWeight(LNWeightTemplate):
@LN_WEIGHT_REGISTER("Triton")
class LNWeight(LNWeightTemplate):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def load_from_disk(self):
if self.weight_name is not None:
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
else:
self.weight = None
if self.bias_name is not None:
if not torch._dynamo.is_compiling():
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
else:
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
else:
self.bias = None
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def apply(self, input_tensor):
input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps)
......
......@@ -5,6 +5,7 @@ import torch
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
try:
import sgl_kernel
......@@ -13,10 +14,11 @@ except ImportError:
class RMSWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
self.weight_name = weight_name
self.eps = eps
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
......@@ -25,26 +27,45 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.config = {}
def load(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffer(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffer()
else:
self._load_default_tensors(weight_dict)
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
if self.create_cuda_buffer:
self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_tensor = weight_dict[self.weight_name]
self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
del weight_dict[self.weight_name]
else:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
def clear(self):
attrs = ["weight", "pinned_weight"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
self.weight = weight_dict[self.weight_name]
def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
tensor = self.lazy_load_file.get_tensor(self.weight_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.weight_name]
return tensor
def _create_cpu_pin_weight(self, tensor):
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_cuda_buffer(self, weight_dict):
weight_tensor = self._get_weight_tensor(weight_dict, use_infer_dtype=self.lazy_load)
self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
def _load_cpu_pin_buffer(self):
weight_tensor = self._get_weight_tensor(use_infer_dtype=True)
self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
@abstractmethod
def apply(self, input_tensor):
......@@ -55,7 +76,7 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.config = config
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
......@@ -63,31 +84,6 @@ class RMSWeightTemplate(metaclass=ABCMeta):
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
def _calculate_size(self):
return self.weight.numel() * self.weight.element_size()
@RMS_WEIGHT_REGISTER("Default")
class RMSWeight(RMSWeightTemplate):
def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, input_tensor):
if GET_SENSITIVE_DTYPE() != GET_DTYPE():
input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
else:
input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
return input_tensor
def state_dict(self, destination=None):
if destination is None:
destination = {}
......@@ -106,6 +102,32 @@ class RMSWeight(RMSWeightTemplate):
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
@RMS_WEIGHT_REGISTER("Default")
class RMSWeight(RMSWeightTemplate):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, input_tensor):
if GET_SENSITIVE_DTYPE() != GET_DTYPE():
input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
else:
input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
return input_tensor
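# Numeric sanity check of _norm() above (values chosen for illustration):
# for x = [3, 4], mean(x**2) = 12.5, so each element is scaled by 1 / sqrt(12.5 + eps).
def _rms_norm_sketch():
    x = torch.tensor([[3.0, 4.0]])
    scale = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)  # ~0.2828
    return x * scale  # ~[[0.8485, 1.1314]]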
@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
......@@ -113,18 +135,13 @@ class RMSWeightSgl(RMSWeight):
self,
weight_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
):
super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def apply(self, input_tensor):
if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
......@@ -146,8 +163,8 @@ class RMSWeightSgl(RMSWeight):
@RMS_WEIGHT_REGISTER("fp32_variance")
class RMSWeightFP32(RMSWeight):
def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def apply(self, input_tensor):
input_dtype = input_tensor.dtype
......@@ -165,8 +182,8 @@ class RMSWeightFP32(RMSWeight):
@RMS_WEIGHT_REGISTER("self_forcing")
class RMSWeightSF(RMSWeight):
def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
......
......@@ -4,52 +4,64 @@ import torch
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
@TENSOR_REGISTER("Default")
class DefaultTensor:
def __init__(self, tensor_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
self.tensor_name = tensor_name
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype).pin_memory()
def load(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffer(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffer()
else:
self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
self._load_default_tensors(weight_dict)
def load(self, weight_dict):
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
if self.create_cuda_buffer:
self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
device = weight_dict[self.tensor_name].device
if device.type == "cpu":
tensor = weight_dict[self.tensor_name]
self.pin_tensor = self._create_cpu_pin_tensor(tensor)
del weight_dict[self.tensor_name]
else:
device = weight_dict[self.tensor_name].device
if device.type == "cpu":
tensor_shape = weight_dict[self.tensor_name].shape
tensor_dtype = weight_dict[self.tensor_name].dtype
self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
self.pin_tensor.copy_(weight_dict[self.tensor_name])
del weight_dict[self.tensor_name]
else:
self.tensor = weight_dict[self.tensor_name]
def clear(self):
attrs = ["tensor", "pinned_tensor"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
setattr(self, attr, None)
def _calculate_size(self):
return self.tensor.numel() * self.tensor.element_size()
self.tensor = weight_dict[self.tensor_name]
def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
tensor = self.lazy_load_file.get_tensor(self.tensor_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.tensor_name]
return tensor
def _create_cpu_pin_tensor(self, tensor):
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_cuda_buffer(self, weight_dict):
tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
self.tensor_cuda_buffer = tensor.to(AI_DEVICE)
def _load_cpu_pin_buffer(self):
tensor = self._get_tensor(use_infer_dtype=True)
self.pin_tensor = self._create_cpu_pin_tensor(tensor)
def to_cuda(self, non_blocking=False):
self.tensor = self.pin_tensor.cuda(non_blocking=non_blocking)
self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_tensor"):
......@@ -69,8 +81,18 @@ class DefaultTensor:
tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
if tensor_name not in destination:
self.tensor = None
return
self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
self.pin_tensor = self.pin_tensor.copy_(tensor)
del tensor
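# Minimal sketch of why the pinned host buffers above matter (shapes are illustrative):
# a pin_memory tensor lets the non_blocking copy in to_cuda() overlap with compute.
def _pinned_copy_sketch():
    src = torch.randn(4096, 4096)
    pinned = torch.empty(src.shape, pin_memory=True, dtype=src.dtype)
    pinned.copy_(src)  # stage into page-locked host memory
    return pinned.to(AI_DEVICE, non_blocking=True)  # asynchronous host-to-device copy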
......@@ -62,17 +62,13 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
self.VAE_IMAGE_SIZE = 1024 * 1024
self.cpu_offload = config.get("cpu_offload", False)
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device(AI_DEVICE)
self.dtype = torch.bfloat16
self.load()
def load(self):
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
if not self.cpu_offload:
self.text_encoder = self.text_encoder.to(self.device)
self.text_encoder = self.text_encoder.to(AI_DEVICE)
self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config["model_path"], "tokenizer"))
if self.config["task"] == "i2i":
......@@ -98,7 +94,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
@torch.no_grad()
def infer(self, text, image_list=None):
if self.cpu_offload:
self.text_encoder.to(self.device)
self.text_encoder.to(AI_DEVICE)
if image_list is not None:
condition_image_list = []
......@@ -133,7 +129,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
images=condition_image_list,
padding=True,
return_tensors="pt",
).to(torch.device(self.device))
).to(AI_DEVICE)
encoder_hidden_states = self.text_encoder(
input_ids=model_inputs.input_ids,
......@@ -156,7 +152,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
txt = [template.format(e) for e in text]
image_info = {}
model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(AI_DEVICE)
encoder_hidden_states = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
......@@ -172,7 +168,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE)
prompt_embeds_mask = encoder_attention_mask
_, seq_len, _ = prompt_embeds.shape
......
......@@ -515,7 +515,7 @@ class T5Encoder(nn.Module):
e = pos_bias
else:
lq, lk = x.size(1), x.size(1)
rel_pos = torch.arange(lk, device="cuda").unsqueeze(0) - torch.arange(lq, device="cuda").unsqueeze(1)
rel_pos = torch.arange(lk, device=AI_DEVICE).unsqueeze(0) - torch.arange(lq, device=AI_DEVICE).unsqueeze(1)
num_buckets = block.pos_embedding.weight.shape[0] // 2
rel_buckets = (rel_pos > 0).long() * num_buckets
rel_pos = torch.abs(rel_pos)
......@@ -532,28 +532,21 @@ class T5Encoder(nn.Module):
return x
def forward_with_offload(self, ids, mask=None):
self.token_embedding = self.token_embedding.to("cuda")
self.pos_embedding = self.pos_embedding.to("cuda") if self.pos_embedding is not None else None
self.token_embedding = self.token_embedding.to(AI_DEVICE)
self.pos_embedding = self.pos_embedding.to(AI_DEVICE) if self.pos_embedding is not None else None
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
self.norm = self.norm.to("cuda")
self.norm = self.norm.to(AI_DEVICE)
for block_idx in range(len(self.blocks)):
self.block_idx = block_idx
if block_idx == 0:
self.offload_manager.cuda_buffers[0].load_state_dict(
self.blocks[block_idx].state_dict(),
block_idx,
)
if block_idx < len(self.blocks) - 1:
self.offload_manager.prefetch_weights(block_idx + 1, self.blocks)
with torch.cuda.stream(self.offload_manager.compute_stream):
x = self.forward_block_with_offload(self.offload_manager.cuda_buffers[0], x, mask, pos_bias=e)
self.offload_manager.swap_blocks()
self.offload_manager.cuda_buffers[0].load_state_dict(
self.blocks[block_idx].state_dict(),
block_idx,
)
x = self.forward_block_with_offload(self.offload_manager.cuda_buffers[0], x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
......
......@@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F
from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE
class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerInfer):
......@@ -101,8 +102,8 @@ class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerI
def infer_using_cache(self, infer_module_out):
residual_img = self.residual_cache[self.scheduler.infer_condition]
residual_txt = self.residual_cache_txt[self.scheduler.infer_condition]
infer_module_out.img.add_(residual_img.cuda())
infer_module_out.txt.add_(residual_txt.cuda())
infer_module_out.img.add_(residual_img.to(AI_DEVICE))
infer_module_out.txt.add_(residual_txt.to(AI_DEVICE))
def clear(self):
self.accumulated_err = {True: 0.0, False: 0.0}
......
......@@ -2,6 +2,9 @@ import torch
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.hunyuan_video.infer.transformer_infer import HunyuanVideo15TransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer):
......@@ -26,6 +29,6 @@ class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer):
self.offload_manager.init_first_buffer(weights.double_blocks)
if block_idx < self.double_blocks_num - 1:
self.offload_manager.prefetch_weights(block_idx + 1, weights.double_blocks)
with torch.cuda.stream(self.offload_manager.compute_stream):
with torch_device_module.stream(self.offload_manager.compute_stream):
infer_module_out.img, infer_module_out.txt = self.infer_double_block(self.offload_manager.cuda_buffers[0], infer_module_out)
self.offload_manager.swap_blocks()
......@@ -234,16 +234,11 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
attention_module=weights.self_attention,
seq_p_group=self.seq_p_group,
use_fp8_comm=self.seq_p_fp8_comm,
model_cls=self.config["model_cls"],
)
else:
attn_out = weights.self_attention.apply(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=seqlen,
max_seqlen_kv=seqlen,
q=query, k=key, v=value, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=seqlen, max_seqlen_kv=seqlen, model_cls=self.config["model_cls"]
)
img_attn, txt_attn = attn_out[:img_seqlen], attn_out[img_seqlen:]
......
......@@ -32,7 +32,8 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
self.seq_p_group = None
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.remove_keys = ["byt5_in", "vision_in"]
self.remove_keys = []
self.remove_keys.extend(["byt5_in", "vision_in"])
self.dit_quantized = self.config.get("dit_quantized", False)
if self.dit_quantized:
assert self.config.get("dit_quant_scheme", "Default") in [
......@@ -98,7 +99,7 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
self.transformer_infer = self.transformer_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
if hasattr(self.transformer_infer, "offload_manager"):
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......@@ -176,7 +177,7 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.config["parallel"]:
if self.device.type != "cpu" and dist.is_initialized():
device = dist.get_rank()
else:
device = str(self.device)
......
......@@ -24,7 +24,7 @@ class HunyuanVideo15TransformerWeights(WeightModule):
if config["cpu_offload"]:
if config.get("offload_granularity", "block") == "block":
self.offload_blocks_num = 2
self.offload_block_buffers = WeightModuleList(
self.offload_block_cuda_buffers = WeightModuleList(
[
MMDoubleStreamBlock(
i,
......@@ -36,8 +36,8 @@ class HunyuanVideo15TransformerWeights(WeightModule):
for i in range(self.offload_blocks_num)
]
)
self.add_module("offload_block_buffers", self.offload_block_buffers)
self.offload_phase_buffers = None
self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
self.offload_phase_cuda_buffers = None
def non_block_weights_to_cuda(self):
self.final_layer.to_cuda()
......@@ -47,23 +47,24 @@ class HunyuanVideo15TransformerWeights(WeightModule):
class MMDoubleStreamBlock(WeightModule):
def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
super().__init__()
self.block_index = block_index
self.task = task
self.config = config
self.is_offload_buffer = is_offload_buffer
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = False
self.lazy_load_file = None
self.add_module(
"img_branch",
MMDoubleStreamBlockImgBranch(block_index, task, config, block_prefix, is_offload_buffer),
MMDoubleStreamBlockImgBranch(block_index, task, config, block_prefix, create_cuda_buffer, create_cpu_buffer),
)
self.add_module(
"txt_branch",
MMDoubleStreamBlockTxtBranch(block_index, task, config, block_prefix, is_offload_buffer),
MMDoubleStreamBlockTxtBranch(block_index, task, config, block_prefix, create_cuda_buffer, create_cpu_buffer),
)
attention_weights_cls = ATTN_WEIGHT_REGISTER[self.config["attn_type"]]
self.add_module("self_attention", attention_weights_cls())
......@@ -75,7 +76,7 @@ class MMDoubleStreamBlock(WeightModule):
class MMDoubleStreamBlockImgBranch(WeightModule):
def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
super().__init__()
self.block_index = block_index
self.task = task
......@@ -93,7 +94,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_mod.linear.weight",
f"{block_prefix}.{self.block_index}.img_mod.linear.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -103,6 +105,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -112,7 +116,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_attn_q.weight",
f"{block_prefix}.{self.block_index}.img_attn_q.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -122,7 +127,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_attn_k.weight",
f"{block_prefix}.{self.block_index}.img_attn_k.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -132,7 +138,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_attn_v.weight",
f"{block_prefix}.{self.block_index}.img_attn_v.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -141,7 +148,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
"img_attn_q_norm",
RMS_WEIGHT_REGISTER[self.rms_type](
f"{block_prefix}.{self.block_index}.img_attn_q_norm.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -150,7 +158,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
"img_attn_k_norm",
RMS_WEIGHT_REGISTER[self.rms_type](
f"{block_prefix}.{self.block_index}.img_attn_k_norm.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -160,7 +169,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_attn_proj.weight",
f"{block_prefix}.{self.block_index}.img_attn_proj.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -170,6 +180,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -179,7 +191,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_mlp.fc1.weight",
f"{block_prefix}.{self.block_index}.img_mlp.fc1.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -189,7 +202,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_mlp.fc2.weight",
f"{block_prefix}.{self.block_index}.img_mlp.fc2.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -197,7 +211,7 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
class MMDoubleStreamBlockTxtBranch(WeightModule):
def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
super().__init__()
self.block_index = block_index
self.task = task
......@@ -215,7 +229,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_mod.linear.weight",
f"{block_prefix}.{self.block_index}.txt_mod.linear.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -225,6 +240,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -234,7 +251,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_attn_q.weight",
f"{block_prefix}.{self.block_index}.txt_attn_q.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -244,7 +262,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_attn_k.weight",
f"{block_prefix}.{self.block_index}.txt_attn_k.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -254,7 +273,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_attn_v.weight",
f"{block_prefix}.{self.block_index}.txt_attn_v.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -263,7 +283,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
"txt_attn_q_norm",
RMS_WEIGHT_REGISTER[self.rms_type](
f"{block_prefix}.{self.block_index}.txt_attn_q_norm.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -272,7 +293,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
"txt_attn_k_norm",
RMS_WEIGHT_REGISTER[self.rms_type](
f"{block_prefix}.{self.block_index}.txt_attn_k_norm.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -282,7 +304,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_attn_proj.weight",
f"{block_prefix}.{self.block_index}.txt_attn_proj.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -292,6 +315,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -301,7 +326,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_mlp.fc1.weight",
f"{block_prefix}.{self.block_index}.txt_mlp.fc1.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -311,7 +337,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_mlp.fc2.weight",
f"{block_prefix}.{self.block_index}.txt_mlp.fc2.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -333,6 +360,8 @@ class FinalLayerWeights(WeightModule):
MM_WEIGHT_REGISTER["Default"](
"final_layer.adaLN_modulation.1.weight",
"final_layer.adaLN_modulation.1.bias",
False,
False,
self.lazy_load,
self.lazy_load_file,
),
......@@ -342,6 +371,8 @@ class FinalLayerWeights(WeightModule):
MM_WEIGHT_REGISTER["Default"](
"final_layer.linear.weight",
"final_layer.linear.bias",
False,
False,
self.lazy_load,
self.lazy_load_file,
),
......@@ -351,6 +382,8 @@ class FinalLayerWeights(WeightModule):
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
False,
False,
self.lazy_load,
self.lazy_load_file,
),
......
......@@ -2,6 +2,9 @@ import torch
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.qwen_image.infer.transformer_infer import QwenImageTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
......@@ -37,7 +40,7 @@ class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
if block_idx < self.num_blocks - 1:
self.offload_manager.prefetch_weights(block_idx + 1, block_weights.blocks)
with torch.cuda.stream(self.offload_manager.compute_stream):
with torch_device_module.stream(self.offload_manager.compute_stream):
encoder_hidden_states, hidden_states = self.infer_block(
block_weight=self.offload_manager.cuda_buffers[0], hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
)
......
......@@ -8,7 +8,6 @@ from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
from lightx2v_platform.base.global_var import AI_DEVICE
from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
from .infer.post_infer import QwenImagePostInfer
......@@ -125,7 +124,7 @@ class QwenImageTransformerModel:
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.config["parallel"]:
if self.device.type != "cpu" and dist.is_initialized():
device = dist.get_rank()
else:
device = str(self.device)
......@@ -284,7 +283,7 @@ class QwenImageTransformerModel:
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
if hasattr(self.transformer_infer, "offload_manager"):
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
def to_cpu(self):
self.pre_weight.to_cpu()
......