Unverified Commit 51be3ad2, authored by gushiqiao, committed by GitHub

[Fix] remove d2h of cpu-offload infer (#476)

parent 2559b3e7
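Editor's note: the point of this change is visible in the diffs below. The offload managers previously rotated three `active_weights` slots and pushed each just-used block back to the host with `to_cpu_async()`, i.e. a device-to-host (d2h) copy on every block. The new code keeps a small set of persistent CUDA buffer modules and refills them from pinned CPU memory with `copy_(..., non_blocking=True)`, so only host-to-device traffic remains. A minimal sketch of that buffer-refill idea, assuming a CUDA device; the tensor shapes and names are illustrative, not taken from the repository:

```python
import torch

# One pinned host copy of a block's weight and one persistent CUDA buffer of
# the same shape; the CUDA buffer is allocated once and reused for every block.
cpu_pinned = torch.empty(3072, 3072, dtype=torch.bfloat16, pin_memory=True)
cuda_buffer = torch.empty_like(cpu_pinned, device="cuda")

load_stream = torch.cuda.Stream()

def refill_buffer(pinned_block_weight: torch.Tensor) -> torch.Tensor:
    # New scheme: overwrite the resident CUDA buffer from pinned memory on a
    # side stream (h2d only). The old scheme instead moved the evicted block's
    # weights back to the CPU, which is the d2h copy this commit removes.
    with torch.cuda.stream(load_stream):
        cuda_buffer.copy_(pinned_block_weight, non_blocking=True)
    return cuda_buffer
```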
@@ -56,12 +56,11 @@
   56
 ],
 "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
-"attn_type": "flash_attn3",
+"attn_type": "sage_attn2",
 "do_true_cfg": true,
 "true_cfg_scale": 4.0,
 "cpu_offload": true,
 "offload_granularity": "block",
-"mm_config": {},
 "CONDITION_IMAGE_SIZE": 147456,
 "USE_IMAGE_ID_IN_PROMPT": true
 }
@@ -61,7 +61,6 @@
 "true_cfg_scale": 4.0,
 "cpu_offload": true,
 "offload_granularity": "block",
-"mm_config": {},
 "CONDITION_IMAGE_SIZE": 1048576,
 "USE_IMAGE_ID_IN_PROMPT": false
 }
@@ -82,6 +82,5 @@
 "attn_type": "flash_attn3",
 "do_true_cfg": false,
 "cpu_offload": true,
-"offload_granularity": "block",
-"mm_config": {}
+"offload_granularity": "block"
 }
@@ -59,7 +59,6 @@
 "attn_type": "flash_attn3",
 "do_true_cfg": true,
 "true_cfg_scale": 4.0,
-"mm_config": {},
 "CONDITION_IMAGE_SIZE": 1048576,
 "USE_IMAGE_ID_IN_PROMPT": false
 }
@@ -59,7 +59,6 @@
 "attn_type": "flash_attn3",
 "do_true_cfg": true,
 "true_cfg_scale": 4.0,
-"mm_config": {},
 "CONDITION_IMAGE_SIZE": 147456,
 "USE_IMAGE_ID_IN_PROMPT": true
 }
{
"batchsize": 1,
"num_channels_latents": 16,
"vae_scale_factor": 8,
"infer_steps": 40,
"guidance_embeds": false,
"num_images_per_prompt": 1,
"vae_latents_mean": [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921
],
"vae_latents_std": [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.916
],
"vae_z_dim": 16,
"feature_caching": "NoCaching",
"transformer_in_channels": 64,
"prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 64,
"_auto_resize": true,
"num_layers": 60,
"attention_out_dim": 3072,
"attention_dim_head": 128,
"axes_dims_rope": [
16,
56,
56
],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3",
"do_true_cfg": true,
"true_cfg_scale": 4.0,
"CONDITION_IMAGE_SIZE": 147456,
"USE_IMAGE_ID_IN_PROMPT": true,
"dit_quantized": true,
"dit_quantized_ckpt": "/path/to/qwen_2509_fp8.safetensors",
"dit_quant_scheme": "fp8-sgl"
}
@@ -81,6 +81,5 @@
 "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
 "attn_type": "flash_attn3",
 "do_true_cfg": true,
-"true_cfg_scale": 4.0,
-"mm_config": {}
+"true_cfg_scale": 4.0
 }
@@ -63,6 +63,17 @@ class WeightModule:
                 module.state_dict(destination)
         return destination
 
+    def load_state_dict(self, destination, block_index, adapter_block_index=None):
+        if destination is None:
+            destination = {}
+        for _, param in self._parameters.items():
+            if param is not None:
+                param.load_state_dict(destination, block_index, adapter_block_index)
+        for _, module in self._modules.items():
+            if module is not None:
+                module.load_state_dict(destination, block_index, adapter_block_index)
+        return destination
+
     def named_parameters(self, prefix=""):
         for name, param in self._parameters.items():
             if param is not None:
...
@@ -9,62 +9,62 @@ from loguru import logger
 class WeightAsyncStreamManager(object):
-    def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
-        self.init(blocks_num, phases_num, offload_ratio)
-        self.compute_stream = torch.cuda.Stream(priority=-1)
-        self.cpu_load_stream = torch.cuda.Stream(priority=0)
+    def __init__(self, offload_granularity):
+        self.offload_granularity = offload_granularity
+        self.init_stream = torch.cuda.Stream(priority=0)
         self.cuda_load_stream = torch.cuda.Stream(priority=0)
+        self.compute_stream = torch.cuda.Stream(priority=-1)
 
-    def init(self, blocks_num, phases_num, offload_ratio):
-        if hasattr(self, "active_weights"):
-            del self.active_weights[:]
-        self.active_weights = [None for _ in range(3)]
-        self.blocks_num = blocks_num
-        self.phases_num = phases_num
-        self.offload_ratio = offload_ratio
-        self.offload_blocks_num = int(self.offload_ratio * self.blocks_num)
-        self.offload_phases_num = self.blocks_num * self.phases_num * self.offload_ratio
+    def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
+        if self.offload_granularity == "block":
+            assert blocks_cuda_buffer is not None
+            self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
+        elif self.offload_granularity == "phase":
+            assert phases_cuda_buffer is not None
+            self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))]
+        else:
+            raise NotImplementedError
 
-    def prefetch_weights(self, block_idx, blocks_weights):
+    def init_first_buffer(self, blocks, adapter_block_idx=None):
+        if self.offload_granularity == "block":
+            with torch.cuda.stream(self.init_stream):
+                self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
+        else:
+            with torch.cuda.stream(self.init_stream):
+                self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
+        self.init_stream.synchronize()
+
+    def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
         with torch.cuda.stream(self.cuda_load_stream):
-            self.active_weights[2] = blocks_weights[block_idx]
-            self.active_weights[2].to_cuda_async()
-        with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx < self.offload_blocks_num:
-                if self.active_weights[1] is not None:
-                    self.active_weights[1].to_cpu_async()
+            self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
 
-    def swap_weights(self):
-        self.compute_stream.synchronize()
-        self.cpu_load_stream.synchronize()
+    def swap_blocks(self):
         self.cuda_load_stream.synchronize()
+        self.compute_stream.synchronize()
 
-        self.active_weights[0], self.active_weights[1] = (
-            self.active_weights[2],
-            self.active_weights[0],
+        self.cuda_buffers[0], self.cuda_buffers[1] = (
+            self.cuda_buffers[1],
+            self.cuda_buffers[0],
         )
 
-    def prefetch_phase(self, block_idx, phase_idx, blocks):
+    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
         with torch.cuda.stream(self.cuda_load_stream):
-            new_phase = blocks[block_idx].compute_phases[phase_idx]
-            new_phase.to_cuda_async()
-            self.active_weights[2] = (phase_idx, blocks[block_idx].compute_phases[phase_idx])
-        with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
-                if self.active_weights[1] is not None:
-                    _, old_phase = self.active_weights[1]
-                    old_phase.to_cpu_async()
+            self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
 
     def swap_phases(self):
-        self.compute_stream.synchronize()
-        self.cpu_load_stream.synchronize()
         self.cuda_load_stream.synchronize()
-        self.active_weights[0], self.active_weights[1] = self.active_weights[2], self.active_weights[0]
-        self.active_weights[2] = None
+        self.compute_stream.synchronize()
 
 
 class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
-    def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2, offload_gra="phase"):
+    def __init__(
+        self,
+        blocks_num,
+        offload_ratio=1,
+        phases_num=1,
+        num_disk_workers=1,
+        max_memory=2,
+        offload_gra="phase",
+    ):
         super().__init__(blocks_num, offload_ratio, phases_num)
         self.offload_gra = offload_gra
         self.worker_stop_event = threading.Event()
@@ -220,12 +220,12 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
         with torch.cuda.stream(self.cuda_load_stream):
             block = self.pin_memory_buffer.get(obj_key)
             block.to_cuda_async()
-            self.active_weights[2] = (obj_key, block)
+            self.cuda_buffers[2] = (obj_key, block)
 
         with torch.cuda.stream(self.cpu_load_stream):
             if block_idx < self.offload_blocks_num:
-                if self.active_weights[1] is not None:
-                    old_key, old_block = self.active_weights[1]
+                if self.cuda_buffers[1] is not None:
+                    old_key, old_block = self.cuda_buffers[1]
                     if self.pin_memory_buffer.exists(old_key):
                         old_block.to_cpu_async()
                         self.pin_memory_buffer.pop(old_key)
@@ -258,12 +258,12 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
         with torch.cuda.stream(self.cuda_load_stream):
             phase = self.pin_memory_buffer.get(obj_key)
             phase.to_cuda_async()
-            self.active_weights[2] = (obj_key, phase)
+            self.cuda_buffers[2] = (obj_key, phase)
 
         with torch.cuda.stream(self.cpu_load_stream):
             if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
-                if self.active_weights[1] is not None:
-                    old_key, old_phase = self.active_weights[1]
+                if self.cuda_buffers[1] is not None:
+                    old_key, old_phase = self.cuda_buffers[1]
                     if self.pin_memory_buffer.exists(old_key):
                         old_phase.to_cpu_async()
                         self.pin_memory_buffer.pop(old_key)
...
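Note on the refactor above: `prefetch_weights` no longer calls `to_cuda_async()`/`to_cpu_async()` on the blocks themselves; it copies a block's pinned CPU state dict into the second resident CUDA buffer, and `swap_blocks` just synchronizes both streams and ping-pongs the buffers. A rough sketch of the resulting block-granularity driving loop, mirroring the T5Encoder and QwenImage loops later in this diff (`run_blocks` and `infer_block` are illustrative names, not part of the repository):

```python
import torch

def run_blocks(manager, blocks, hidden_states, infer_block):
    # manager: WeightAsyncStreamManager("block") whose cuda_buffers were set up
    # via init_cuda_buffer(); blocks: CPU-resident block weight containers.
    for block_idx in range(len(blocks)):
        if block_idx == 0:
            # Fill buffer 0 with block 0's pinned CPU weights (h2d copy).
            manager.init_first_buffer(blocks)
        if block_idx < len(blocks) - 1:
            # Stage the next block into buffer 1 on cuda_load_stream while
            # the current block computes on compute_stream.
            manager.prefetch_weights(block_idx + 1, blocks)
        with torch.cuda.stream(manager.compute_stream):
            hidden_states = infer_block(manager.cuda_buffers[0], hidden_states)
        # Sync both streams, then swap buffers; nothing is copied back to CPU.
        manager.swap_blocks()
    return hidden_states
```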
@@ -27,3 +27,6 @@ class AttnWeightTemplate(metaclass=ABCMeta):
         if destination is None:
             destination = {}
         return destination
+
+    def load_state_dict(self, destination, block_index, adapter_block_index=None):
+        return {}
@@ -34,10 +34,30 @@ class Conv3dWeight(Conv3dWeightTemplate):
         super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
 
     def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name]
-        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
-        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
-        self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias_name is not None else None
+        device = weight_dict[self.weight_name].device
+        if device.type == "cuda":
+            self.weight = weight_dict[self.weight_name]
+            if self.bias_name is not None:
+                self.bias = weight_dict[self.bias_name]
+            else:
+                self.bias = None
+        elif device.type == "cpu":
+            weight_shape = weight_dict[self.weight_name].shape
+            weight_dtype = weight_dict[self.weight_name].dtype
+            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+            self.pin_weight.copy_(weight_dict[self.weight_name])
+            if self.bias_name is not None:
+                bias_shape = weight_dict[self.bias_name].shape
+                bias_dtype = weight_dict[self.bias_name].dtype
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
+            else:
+                self.bias = None
+                self.pin_bias = None
+            del weight_dict[self.weight_name]
+        else:
+            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
 
     def apply(self, input_tensor):
         input_tensor = torch.nn.functional.conv3d(
@@ -51,22 +71,27 @@ class Conv3dWeight(Conv3dWeightTemplate):
         )
         return input_tensor
 
-    def to_cpu(self, non_blocking=False):
-        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
-        if self.bias is not None:
-            self.bias = self.bias.to("cpu", non_blocking=non_blocking)
-
     def to_cuda(self, non_blocking=False):
-        self.weight = self.weight.cuda(non_blocking=non_blocking)
-        if self.bias is not None:
-            self.bias = self.bias.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        if hasattr(self, "pin_bias") and self.pin_bias is not None:
+            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+
+    def to_cpu(self, non_blocking=False):
+        if hasattr(self, "pin_weight"):
+            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+            if self.bias is not None:
+                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+        else:
+            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+            if hasattr(self, "bias") and self.bias is not None:
+                self.bias = self.bias.to("cpu", non_blocking=non_blocking)
 
     def state_dict(self, destination=None):
         if destination is None:
             destination = {}
-        destination[self.weight_name] = self.weight.cpu().detach().clone()
-        if self.bias is not None:
-            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight  # .cpu().detach().clone().contiguous()
+        if self.bias_name is not None:
+            destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias  # .cpu().detach().clone()
         return destination
 
     def clear(self):
...
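Note: the new `load`/`to_cuda` paths above follow a single staging pattern: CPU-resident weights are copied once into page-locked (pinned) host buffers, and later `.cuda(non_blocking=True)` calls can then overlap with compute instead of blocking. A small self-contained sketch of that pattern (`stage_pinned` is an illustrative helper, not a repository function):

```python
import torch

def stage_pinned(cpu_tensor: torch.Tensor) -> torch.Tensor:
    # Allocate a page-locked host buffer and copy the weight into it once,
    # as the pin_weight / pin_bias fields above do.
    pinned = torch.empty(cpu_tensor.shape, dtype=cpu_tensor.dtype, pin_memory=True)
    pinned.copy_(cpu_tensor)
    return pinned

# Illustrative usage (requires a CUDA device):
# w_pin = stage_pinned(torch.randn(128, 128, dtype=torch.bfloat16))
# w_gpu = w_pin.cuda(non_blocking=True)  # async h2d copy from pinned memory
```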
+import re
 from abc import ABCMeta
 import torch
@@ -7,30 +8,57 @@ from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
 class EmbeddingWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None):
+    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
         self.weight_name = weight_name
+        self.create_cuda_buffer = create_cuda_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
+        self.is_post_adapter = is_post_adapter
         self.config = {}
 
     def load(self, weight_dict):
         if not self.lazy_load:
-            if self.weight_name is not None:
-                self.weight = weight_dict[self.weight_name]
-                self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
+            if self.create_cuda_buffer:
+                self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
             else:
-                self.weight = None
-            del weight_dict[self.weight_name]
+                device = weight_dict[self.weight_name].device
+                if device.type == "cuda":
+                    self.weight = weight_dict[self.weight_name]
+                elif device.type == "cpu":
+                    weight_shape = weight_dict[self.weight_name].shape
+                    weight_dtype = weight_dict[self.weight_name].dtype
+                    self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+                    self.pin_weight.copy_(weight_dict[self.weight_name])
+                    del weight_dict[self.weight_name]
+                else:
+                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+
+    def to_cuda(self, non_blocking=False):
+        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
 
     def to_cpu(self, non_blocking=False):
-        if hasattr(self, "pinned_weight"):
-            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+        if hasattr(self, "pin_weight"):
+            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
         else:
             self.weight = self.weight.to("cpu", non_blocking=non_blocking)
 
-    def to_cuda(self, non_blocking=False):
-        self.weight = self.weight.cuda(non_blocking=non_blocking)
+    def state_dict(self, destination=None):
+        if destination is None:
+            destination = {}
+        destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
+        return destination
+
+    def load_state_dict(self, destination, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            assert adapter_block_index is not None
+            weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+        else:
+            weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+
+        if weight_name not in destination:
+            self.weight = None
+            return
+        self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
 
 
 @EMBEDDING_WEIGHT_REGISTER("Default")
...
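Note: the `load_state_dict` methods added throughout this commit locate a block's weights inside a flat CPU state dict by rewriting the first numeric segment of the stored name to the requested block (or adapter block) index. A standalone illustration of that remapping; the key names are made up for the example:

```python
import re

def remap_block_index(weight_name: str, block_index: int) -> str:
    # Replace only the first ".<digits>" segment, as the new load_state_dict does.
    return re.sub(r"\.\d+", lambda m: f".{block_index}", weight_name, count=1)

print(remap_block_index("blocks.0.attn.q.weight", 17))  # blocks.17.attn.q.weight
print(remap_block_index("blocks.0.ffn.fc1.weight", 3))  # blocks.3.ffn.fc1.weight
```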
+import re
 from abc import ABCMeta, abstractmethod
 import torch
@@ -7,28 +8,54 @@ from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
 class LNWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name=None, bias_name=None, lazy_load=False, lazy_load_file=None, eps=1e-6):
+    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
         self.weight_name = weight_name
         self.bias_name = bias_name
         self.eps = eps
+        self.create_cuda_buffer = create_cuda_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
+        self.is_post_adapter = is_post_adapter
         self.config = {}
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 
     def load(self, weight_dict):
         if not self.lazy_load:
-            if self.weight_name is not None:
-                self.weight = weight_dict[self.weight_name]
-                self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
+            if self.create_cuda_buffer:
+                if self.weight_name is not None:
+                    self.weight_cuda_buffer = weight_dict[self.weight_name].cuda().t()
+                if self.bias_name is not None:
+                    self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
             else:
-                self.weight = None
-            if self.bias_name is not None:
-                self.bias = weight_dict[self.bias_name]
-                self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
-            else:
-                self.bias = None
+                if self.weight_name is not None:
+                    device = weight_dict[self.weight_name].device
+                    if device.type == "cuda":
+                        self.weight = weight_dict[self.weight_name]
+                        if self.bias_name is not None:
+                            self.bias = weight_dict[self.bias_name]
+                        else:
+                            self.bias = None
+                    elif device.type == "cpu":
+                        weight_shape = weight_dict[self.weight_name].shape
+                        weight_dtype = weight_dict[self.weight_name].dtype
+                        self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+                        self.pin_weight.copy_(weight_dict[self.weight_name])
+                        if self.bias_name is not None:
+                            bias_shape = weight_dict[self.bias_name].shape
+                            bias_dtype = weight_dict[self.bias_name].dtype
+                            self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                            self.pin_bias.copy_(weight_dict[self.bias_name])
+                        else:
+                            self.bias = None
+                            self.pin_bias = None
+                        del weight_dict[self.weight_name]
+                    else:
+                        raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                else:
+                    self.weight = None
+                    self.bias = None
 
     def _calculate_size(self):
         if self.weight is None:
@@ -52,37 +79,61 @@ class LNWeightTemplate(metaclass=ABCMeta):
         if config is not None:
             self.config = config
 
-    def to_cpu(self, non_blocking=False):
-        if hasattr(self, "pinned_weight"):
-            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
-            if self.bias is not None:
-                self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
-        else:
-            if self.weight is not None:
-                self.weight = self.weight.to("cpu", non_blocking=non_blocking)
-            if self.bias is not None:
-                self.bias = self.bias.to("cpu", non_blocking=non_blocking)
-
     def to_cuda(self, non_blocking=False):
-        if self.weight is not None:
-            self.weight = self.weight.cuda(non_blocking=non_blocking)
-        if self.bias is not None:
-            self.bias = self.bias.cuda(non_blocking=non_blocking)
+        if hasattr(self, "pin_weight") and self.pin_weight is not None:
+            self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        else:
+            self.weight = None
+        if hasattr(self, "pin_bias") and self.pin_bias is not None:
+            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+        else:
+            self.bias = None
+
+    def to_cpu(self, non_blocking=False):
+        if hasattr(self, "pin_weight") and self.pin_weight is not None:
+            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+            if self.bias is not None:
+                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+        elif hasattr(self, "weight") and self.weight is not None:
+            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+            if hasattr(self, "bias") and self.bias is not None:
+                self.bias = self.bias.to("cpu", non_blocking=non_blocking)
 
     def state_dict(self, destination=None):
         if destination is None:
             destination = {}
-        if self.weight is not None:
-            destination[self.weight_name] = self.weight.cpu().detach().clone()
-        if self.bias is not None:
-            destination[self.bias_name] = self.bias.cpu().detach().clone()
+        if self.weight_name is not None:
+            destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
+        if self.bias_name is not None:
+            destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
         return destination
 
+    def load_state_dict(self, destination, block_index, adapter_block_index=None):
+        if self.weight_name is not None:
+            if self.is_post_adapter:
+                assert adapter_block_index is not None
+                weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+            else:
+                weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+            if weight_name not in destination:
+                self.weight = None
+                return
+            self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
+        else:
+            self.weight = None
+        if self.bias_name is not None:
+            bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+            self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
+        else:
+            self.bias = None
+
 
 @LN_WEIGHT_REGISTER("Default")
 class LNWeight(LNWeightTemplate):
-    def __init__(self, weight_name=None, bias_name=None, lazy_load=False, lazy_load_file=None, eps=1e-6):
-        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file, eps)
+    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
 
     def load_from_disk(self):
         if self.weight_name is not None:
...
+import re
 from abc import ABCMeta, abstractmethod
 import torch
@@ -12,19 +13,33 @@ except ImportError:
 class RMSWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
+    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
         self.weight_name = weight_name
         self.eps = eps
+        self.create_cuda_buffer = create_cuda_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
+        self.is_post_adapter = is_post_adapter
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
         self.config = {}
 
     def load(self, weight_dict):
         if not self.lazy_load:
-            self.weight = weight_dict[self.weight_name]
-            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
+            if self.create_cuda_buffer:
+                self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
+            else:
+                device = weight_dict[self.weight_name].device
+                if device.type == "cuda":
+                    self.weight = weight_dict[self.weight_name]
+                elif device.type == "cpu":
+                    weight_shape = weight_dict[self.weight_name].shape
+                    weight_dtype = weight_dict[self.weight_name].dtype
+                    self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+                    self.pin_weight.copy_(weight_dict[self.weight_name])
+                    del weight_dict[self.weight_name]
+                else:
+                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
 
     def clear(self):
         attrs = ["weight", "pinned_weight"]
@@ -41,28 +56,23 @@ class RMSWeightTemplate(metaclass=ABCMeta):
         if config is not None:
             self.config = config
 
+    def to_cuda(self, non_blocking=False):
+        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+
     def to_cpu(self, non_blocking=False):
-        if hasattr(self, "pinned_weight"):
-            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+        if hasattr(self, "pin_weight"):
+            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
         else:
             self.weight = self.weight.to("cpu", non_blocking=non_blocking)
 
-    def to_cuda(self, non_blocking=False):
-        self.weight = self.weight.cuda(non_blocking=non_blocking)
-
     def _calculate_size(self):
         return self.weight.numel() * self.weight.element_size()
 
 
 @RMS_WEIGHT_REGISTER("Default")
 class RMSWeight(RMSWeightTemplate):
-    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
-        super().__init__(weight_name, lazy_load, lazy_load_file, eps)
-
-    def load(self, weight_dict):
-        if not self.lazy_load:
-            self.weight = weight_dict[self.weight_name]
-            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
+    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
 
     def load_from_disk(self):
         if not torch._dynamo.is_compiling():
@@ -83,19 +93,26 @@ class RMSWeight(RMSWeightTemplate):
     def state_dict(self, destination=None):
         if destination is None:
             destination = {}
-        destination[self.weight_name] = self.weight.cpu().detach().clone()
+        destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
         return destination
 
+    def load_state_dict(self, destination, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            assert adapter_block_index is not None
+            weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+        else:
+            weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+        if weight_name not in destination:
+            self.weight = None
+            return
+        self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
+
 
 @RMS_WEIGHT_REGISTER("sgl-kernel")
 class RMSWeightSgl(RMSWeight):
-    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
-        super().__init__(weight_name, lazy_load, lazy_load_file, eps)
-
-    def load(self, weight_dict):
-        if not self.lazy_load:
-            self.weight = weight_dict[self.weight_name]
-            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
+    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
 
     def load_from_disk(self):
         if not torch._dynamo.is_compiling():
@@ -123,8 +140,8 @@ class RMSWeightSgl(RMSWeight):
 
 @RMS_WEIGHT_REGISTER("fp32_variance")
 class RMSWeightFP32(RMSWeight):
-    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
-        super().__init__(weight_name, lazy_load, lazy_load_file, eps)
+    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
 
     def apply(self, input_tensor):
         input_dtype = input_tensor.dtype
@@ -142,8 +159,8 @@ class RMSWeightFP32(RMSWeight):
 
 @RMS_WEIGHT_REGISTER("self_forcing")
 class RMSWeightSF(RMSWeight):
-    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
-        super().__init__(weight_name, lazy_load, lazy_load_file, eps)
+    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
 
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
...
+import re
 import torch
 
 from lightx2v.utils.envs import *
@@ -6,10 +8,12 @@ from lightx2v.utils.registry_factory import TENSOR_REGISTER
 @TENSOR_REGISTER("Default")
 class DefaultTensor:
-    def __init__(self, tensor_name, lazy_load=False, lazy_load_file=None):
+    def __init__(self, tensor_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
         self.tensor_name = tensor_name
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
+        self.is_post_adapter = is_post_adapter
+        self.create_cuda_buffer = create_cuda_buffer
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
@@ -21,8 +25,20 @@ class DefaultTensor:
     def load(self, weight_dict):
         if not self.lazy_load:
-            self.tensor = weight_dict[self.tensor_name]
-            self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)
+            if self.create_cuda_buffer:
+                self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
+            else:
+                device = weight_dict[self.tensor_name].device
+                if device.type == "cuda":
+                    self.tensor = weight_dict[self.tensor_name]
+                elif device.type == "cpu":
+                    tensor_shape = weight_dict[self.tensor_name].shape
+                    tensor_dtype = weight_dict[self.tensor_name].dtype
+                    self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
+                    self.pin_tensor.copy_(weight_dict[self.tensor_name])
+                    del weight_dict[self.tensor_name]
+                else:
+                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
 
     def clear(self):
         attrs = ["tensor", "pinned_tensor"]
@@ -34,17 +50,29 @@ class DefaultTensor:
     def _calculate_size(self):
         return self.tensor.numel() * self.tensor.element_size()
 
+    def to_cuda(self, non_blocking=False):
+        self.tensor = self.pin_tensor.cuda(non_blocking=non_blocking)
+
     def to_cpu(self, non_blocking=False):
-        if hasattr(self, "pinned_tensor"):
-            self.tensor = self.pinned_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
+        if hasattr(self, "pin_tensor"):
+            self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
         else:
             self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
 
-    def to_cuda(self, non_blocking=False):
-        self.tensor = self.tensor.cuda(non_blocking=non_blocking)
-
     def state_dict(self, destination=None):
         if destination is None:
             destination = {}
-        destination[self.tensor_name] = self.tensor.cpu().detach().clone()
+        destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
         return destination
+
+    def load_state_dict(self, destination, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            assert adapter_block_index is not None
+            tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
+        else:
+            tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
+        if tensor_name not in destination:
+            self.tensor = None
+            return
+        self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)
@@ -48,12 +48,14 @@ class T5OffloadBlocksWeights(WeightModule):
     def __init__(self, block_nums, mm_type):
         super().__init__()
         self.block_nums = block_nums
+        self.offload_block_buffers = WeightModuleList([T5OffloadSelfAttention(i, mm_type, create_cuda_buffer=True) for i in range(2)])
         self.blocks = WeightModuleList([T5OffloadSelfAttention(i, mm_type) for i in range(block_nums)])
+        self.add_module("offload_block_buffers", self.offload_block_buffers)
         self.add_module("blocks", self.blocks)
 
 
 class T5OffloadSelfAttention(WeightModule):
-    def __init__(self, block_index, mm_type, block_prefix="blocks"):
+    def __init__(self, block_index, mm_type, block_prefix="blocks", create_cuda_buffer=False):
         super().__init__()
         self.block_index = block_index
         if mm_type is None:
@@ -62,81 +64,66 @@ class T5OffloadSelfAttention(WeightModule):
         self.add_module(
             "norm1",
-            RMS_WEIGHT_REGISTER["sgl-kernel"](
-                f"{block_prefix}.{self.block_index}.norm1.weight",
-            ),
+            RMS_WEIGHT_REGISTER["sgl-kernel"](f"{block_prefix}.{self.block_index}.norm1.weight", create_cuda_buffer),
         )
         self.add_module(
             "norm2",
-            RMS_WEIGHT_REGISTER["sgl-kernel"](
-                f"{block_prefix}.{self.block_index}.norm2.weight",
-            ),
+            RMS_WEIGHT_REGISTER["sgl-kernel"](f"{block_prefix}.{self.block_index}.norm2.weight", create_cuda_buffer),
        )
         self.add_module(
             "pos_embedding",
-            EMBEDDING_WEIGHT_REGISTER["Default"](
-                f"{block_prefix}.{self.block_index}.pos_embedding.embedding.weight",
-            ),
+            EMBEDDING_WEIGHT_REGISTER["Default"](f"{block_prefix}.{self.block_index}.pos_embedding.embedding.weight", create_cuda_buffer),
         )
         self.compute_phases = WeightModuleList(
             [
-                T5OffloadAttention(
-                    block_index,
-                    block_prefix,
-                    mm_type,
-                ),
-                T5OffloadFeedForward(
-                    block_index,
-                    block_prefix,
-                    mm_type,
-                ),
+                T5OffloadAttention(block_index, block_prefix, mm_type, create_cuda_buffer),
+                T5OffloadFeedForward(block_index, block_prefix, mm_type, create_cuda_buffer),
             ]
         )
         self.add_module("compute_phases", self.compute_phases)
 
 
 class T5OffloadAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, mm_type):
+    def __init__(self, block_index, block_prefix, mm_type, create_cuda_buffer=False):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.add_module(
             "attn_q",
-            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.q.weight", None),
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.q.weight", None, create_cuda_buffer),
         )
         self.add_module(
             "attn_k",
-            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.k.weight", None),
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.k.weight", None, create_cuda_buffer),
         )
         self.add_module(
             "attn_v",
-            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.v.weight", None),
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.v.weight", None, create_cuda_buffer),
         )
         self.add_module(
             "attn_o",
-            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.o.weight", None),
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.o.weight", None, create_cuda_buffer),
         )
 
 
 class T5OffloadFeedForward(WeightModule):
-    def __init__(self, block_index, block_prefix, mm_type):
+    def __init__(self, block_index, block_prefix, mm_type, create_cuda_buffer=False):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.add_module(
             "ffn_fc1",
-            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.fc1.weight", None),
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.fc1.weight", None, create_cuda_buffer),
        )
         self.add_module(
             "ffn_fc2",
-            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.fc2.weight", None),
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.fc2.weight", None, create_cuda_buffer),
         )
         self.add_module(
             "ffn_gate_0",
-            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.gate.0.weight", None),
+            MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.gate.0.weight", None, create_cuda_buffer),
         )
         self.gelu = GELU()
@@ -453,8 +440,9 @@ class T5Encoder(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
         if cpu_offload:
-            self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=num_layers)
+            self.offload_manager = WeightAsyncStreamManager(offload_granularity="block")
             self.blocks_weights = T5OffloadBlocksWeights(num_layers, quant_scheme)
+            self.offload_manager.init_cuda_buffer(self.blocks_weights.offload_block_buffers, None)
             self.blocks = self.blocks_weights.blocks
         else:
             self.blocks = nn.ModuleList(
@@ -551,15 +539,17 @@ class T5Encoder(nn.Module):
         for block_idx in range(len(self.blocks)):
             self.block_idx = block_idx
             if block_idx == 0:
-                self.weights_stream_mgr.active_weights[0] = self.blocks[0]
-                self.weights_stream_mgr.active_weights[0].to_cuda()
+                self.offload_manager.cuda_buffers[0].load_state_dict(
+                    self.blocks[block_idx].state_dict(),
+                    block_idx,
+                )
 
             if block_idx < len(self.blocks) - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, self.blocks)
+                self.offload_manager.prefetch_weights(block_idx + 1, self.blocks)
 
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.forward_block_with_offload(self.blocks[block_idx], x, mask, pos_bias=e)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.forward_block_with_offload(self.offload_manager.cuda_buffers[0], x, mask, pos_bias=e)
 
-            self.weights_stream_mgr.swap_weights()
+            self.offload_manager.swap_weights()
 
         x = self.norm(x)
         x = self.dropout(x)
@@ -826,10 +816,10 @@ if __name__ == "__main__":
     import time
 
     checkpoint_dir = ""
-    t5_checkpoint = "./models_t5_umt5-xxl-enc-bf16.pth"
-    t5_tokenizer = "./google/umt5-xxl"
+    t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
+    t5_tokenizer = "google/umt5-xxl"
 
-    cpu_offload = True
+    cpu_offload = False
     if cpu_offload:
         device = torch.device("cpu")
     else:
...
@@ -24,7 +24,7 @@ class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
             assert NotImplementedError
 
         if offload_granularity != "model":
-            self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.num_blocks, offload_ratio=self.offload_ratio, phases_num=self.phases_num)
+            self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
         else:
             assert NotImplementedError
@@ -32,16 +32,16 @@ class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
         for block_idx in range(self.num_blocks):
             self.block_idx = block_idx
             if block_idx == 0:
-                self.weights_stream_mgr.active_weights[0] = block_weights.blocks[0]
-                self.weights_stream_mgr.active_weights[0].to_cuda()
+                self.offload_manager.init_first_buffer(block_weights.blocks)
 
             if block_idx < self.num_blocks - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, block_weights.blocks)
+                self.offload_manager.prefetch_weights(block_idx + 1, block_weights.blocks)
 
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
+            with torch.cuda.stream(self.offload_manager.compute_stream):
                 encoder_hidden_states, hidden_states = self.infer_block(
-                    block_weight=block_weights.blocks[block_idx], hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
+                    block_weight=self.offload_manager.cuda_buffers[0], hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
                 )
 
-            self.weights_stream_mgr.swap_weights()
+            self.offload_manager.swap_blocks()
 
         return encoder_hidden_states, hidden_states
File mode changed from 100644 to 100755
+import gc
 import glob
 import json
 import os
@@ -34,8 +35,7 @@ class QwenImageTransformerModel:
         self.in_channels = transformer_config["in_channels"]
         self.attention_kwargs = {}
 
-        self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
-        self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
+        self.dit_quantized = self.config.get("dit_quantized", False)
 
         self._init_infer_class()
         self._init_weights()
@@ -63,15 +63,18 @@ class QwenImageTransformerModel:
         if weight_dict is None:
             is_weight_loader = self._should_load_weights()
             if is_weight_loader:
-                if not self.dit_quantized or self.weight_auto_quant:
+                if not self.dit_quantized:
                     # Load original weights
                     weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                 else:
                     # Load quantized weights
-                    assert NotImplementedError
+                    if not self.config.get("lazy_load", False):
+                        weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
+                    else:
+                        weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
 
-            if self.config.get("device_mesh") is not None:
-                weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
+            if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
+                weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
 
             self.original_weight_dict = weight_dict
         else:
@@ -81,36 +84,143 @@ class QwenImageTransformerModel:
         self.pre_weight = self.pre_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
         self.post_weight = self.post_weight_class(self.config)
+        if not self._should_init_empty_model():
+            self._apply_weights()
+
+    def _apply_weights(self, weight_dict=None):
+        if weight_dict is not None:
+            self.original_weight_dict = weight_dict
+            del weight_dict
+            gc.collect()
 
         # Load weights into containers
         self.pre_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)
         self.post_weight.load(self.original_weight_dict)
+        del self.original_weight_dict
+        torch.cuda.empty_cache()
+        gc.collect()
 
     def _should_load_weights(self):
         """Determine if current rank should load weights from disk."""
         if self.config.get("device_mesh") is None:
             # Single GPU mode
             return True
         elif dist.is_initialized():
-            # Multi-GPU mode, only rank 0 loads
-            if dist.get_rank() == 0:
-                logger.info(f"Loading weights from {self.model_path}")
+            if self.config.get("load_from_rank0", False):
+                # Multi-GPU mode, only rank 0 loads
+                if dist.get_rank() == 0:
+                    logger.info(f"Loading weights from {self.model_path}")
+                    return True
+            else:
                 return True
         return False
 
+    def _should_init_empty_model(self):
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
+            return True
+        return False
+
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
-        with safe_open(file_path, framework="pt", device=str(self.device)) as f:
-            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}
+        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
+
+        if self.device.type == "cuda" and dist.is_initialized():
+            device = torch.device("cuda:{}".format(dist.get_rank()))
+        else:
+            device = self.device
+
+        with safe_open(file_path, framework="pt", device=str(device)) as f:
+            return {
+                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
+                for key in f.keys()
+                if not any(remove_key in key for remove_key in remove_keys)
+            }
 
     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
+        if self.config.get("dit_original_ckpt", None):
+            safetensors_path = self.config["dit_original_ckpt"]
+        else:
+            safetensors_path = self.model_path
+
+        if os.path.isdir(safetensors_path):
+            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+        else:
+            safetensors_files = [safetensors_path]
+
         weight_dict = {}
         for file_path in safetensors_files:
+            logger.info(f"Loading weights from {file_path}")
             file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict
 
+    def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
+        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
+        if self.config.get("dit_quantized_ckpt", None):
+            safetensors_path = self.config["dit_quantized_ckpt"]
+        else:
+            safetensors_path = self.model_path
+
+        if os.path.isdir(safetensors_path):
+            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+        else:
+            safetensors_files = [safetensors_path]
+            safetensors_path = os.path.dirname(safetensors_path)
+
+        weight_dict = {}
+        for safetensor_path in safetensors_files:
+            with safe_open(safetensor_path, framework="pt") as f:
+                logger.info(f"Loading weights from {safetensor_path}")
+                for k in f.keys():
+                    if any(remove_key in k for remove_key in remove_keys):
+                        continue
+                    if f.get_tensor(k).dtype in [
+                        torch.float16,
+                        torch.bfloat16,
+                        torch.float,
+                    ]:
+                        if unified_dtype or all(s not in k for s in sensitive_layer):
+                            weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
+                        else:
+                            weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
+                    else:
+                        weight_dict[k] = f.get_tensor(k).to(self.device)
+
+        if self.config.get("dit_quant_scheme", "Default") == "nvfp4":
+            calib_path = os.path.join(safetensors_path, "calib.pt")
+            logger.info(f"[CALIB] Loaded calibration data from: {calib_path}")
+            calib_data = torch.load(calib_path, map_location="cpu")
+            for k, v in calib_data["absmax"].items():
+                weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device)
+
+        return weight_dict
+
+    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):  # Need rewrite
+        lazy_load_model_path = self.dit_quantized_ckpt
+        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
+
+        pre_post_weight_dict = {}
+        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
+        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
+            for k in f.keys():
+                if f.get_tensor(k).dtype in [
+                    torch.float16,
+                    torch.bfloat16,
+                    torch.float,
+                ]:
+                    if unified_dtype or all(s not in k for s in sensitive_layer):
+                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
+                    else:
+                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
+                else:
+                    pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
+
+        return pre_post_weight_dict
 
-    def _load_weights_distribute(self, weight_dict, is_weight_loader):
+    def _load_weights_from_rank0(self, weight_dict, is_weight_loader):
+        logger.info("Loading distributed weights")
         global_src_rank = 0
         target_device = "cpu" if self.cpu_offload else "cuda"
@@ -165,12 +275,15 @@ class QwenImageTransformerModel:
                     tensor.copy_(tensor, non_blocking=False)
 
         logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
         return distributed_weight_dict
 
     def _init_infer(self):
         self.transformer_infer = self.transformer_infer_class(self.config)
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
+        if hasattr(self.transformer_infer, "offload_manager"):
+            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
 
     def to_cpu(self):
         self.pre_weight.to_cpu()
...