Unverified Commit 74eeb429 authored by Gu Shiqiao, committed by GitHub

reconstruct disk offload and fix lightx2v_platform bugs (#558)


Co-authored-by: helloyongyang <yongyang1030@163.com>
parent f7cdbcb5
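The recurring change below swaps hard-coded `"cuda"` / `.cuda()` calls for the platform-agnostic `AI_DEVICE` string. A minimal sketch of the pattern, assuming (as the imports suggest) that `AI_DEVICE` is a plain device string exported by `lightx2v_platform.base.global_var`:

```python
import torch

# Assumption: stand-in for lightx2v_platform.base.global_var.AI_DEVICE,
# a device string chosen once at startup ("cuda", "npu", "mps", ...).
AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Accelerator backends expose a torch.<name> module with a common surface
# (streams, empty_cache, ...), so the module can be resolved by name.
torch_device_module = getattr(torch, AI_DEVICE)

x = torch.randn(4, 4).to(AI_DEVICE)  # portable replacement for x.cuda()
if AI_DEVICE == "cuda":
    torch_device_module.empty_cache()  # i.e. torch.cuda.empty_cache()
```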
@@ -4,6 +4,7 @@ from einops import rearrange
 from torch.nn import functional as F
 from lightx2v.models.schedulers.scheduler import BaseScheduler
+from lightx2v_platform.base.global_var import AI_DEVICE
 from .posemb_layers import get_nd_rotary_pos_embed
...
@@ -13,6 +13,10 @@ from diffusers.models.modeling_utils import ModelMixin
 from einops import rearrange
 from torch import Tensor, nn
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
+
 
 @dataclass
 class DecoderOutput(BaseOutput):
@@ -725,7 +729,7 @@ class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
     @torch.no_grad()
     def encode(self, x: Tensor, return_dict: bool = True):
         if self.cpu_offload:
-            self.encoder = self.encoder.to("cuda")
+            self.encoder = self.encoder.to(AI_DEVICE)
 
         def _encode(x):
             if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize:
@@ -752,7 +756,7 @@ class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
     @torch.no_grad()
     def decode(self, z: Tensor, return_dict: bool = True, generator=None):
         if self.cpu_offload:
-            self.decoder = self.decoder.to("cuda")
+            self.decoder = self.decoder.to(AI_DEVICE)
 
         def _decode(z):
             if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize:
...
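Both `encode` and `decode` follow the same on-demand offload pattern. A minimal self-contained sketch of the idea (the move back to CPU is not visible in these hunks, so treat that part as an assumption about the surrounding code):

```python
import torch

def run_offloaded(module: torch.nn.Module, x: torch.Tensor, device: str) -> torch.Tensor:
    # Park the module on CPU between calls; borrow the accelerator only
    # for the duration of the forward pass.
    module.to(device)
    try:
        with torch.no_grad():
            return module(x.to(device))
    finally:
        module.to("cpu")  # assumed counterpart to the .to(AI_DEVICE) above
```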
@@ -873,8 +873,8 @@ class WanVAE:
             2.8251,
             1.9160,
         ]
-        self.mean = torch.tensor(mean, dtype=dtype, device=device)
-        self.inv_std = 1.0 / torch.tensor(std, dtype=dtype, device=device)
+        self.mean = torch.tensor(mean, dtype=dtype, device=AI_DEVICE)
+        self.inv_std = 1.0 / torch.tensor(std, dtype=dtype, device=AI_DEVICE)
         self.scale = [self.mean, self.inv_std]
 
         # (height, width, world_size) -> (world_size_h, world_size_w)
...
@@ -7,6 +7,9 @@ import torch.nn.functional as F
 from einops import rearrange
 from lightx2v.utils.utils import load_weights
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
 
 __all__ = [
     "Wan2_2_VAE",
@@ -933,7 +936,7 @@ class Wan2_2_VAE:
                 -0.0667,
             ],
             dtype=dtype,
-            device=device,
+            device=AI_DEVICE,
         )
         self.std = torch.tensor(
             [
@@ -987,7 +990,7 @@ class Wan2_2_VAE:
                 0.7744,
             ],
             dtype=dtype,
-            device=device,
+            device=AI_DEVICE,
         )
         self.inv_std = 1.0 / self.std
         self.scale = [self.mean, self.inv_std]
@@ -1011,11 +1014,11 @@ class Wan2_2_VAE:
         self.scale = [self.mean, self.inv_std]
 
     def to_cuda(self):
-        self.model.encoder = self.model.encoder.to("cuda")
-        self.model.decoder = self.model.decoder.to("cuda")
-        self.model = self.model.to("cuda")
-        self.mean = self.mean.cuda()
-        self.inv_std = self.inv_std.cuda()
+        self.model.encoder = self.model.encoder.to(AI_DEVICE)
+        self.model.decoder = self.model.decoder.to(AI_DEVICE)
+        self.model = self.model.to(AI_DEVICE)
+        self.mean = self.mean.to(AI_DEVICE)
+        self.inv_std = self.inv_std.to(AI_DEVICE)
         self.scale = [self.mean, self.inv_std]
 
     def encode(self, video):
...
@@ -3,6 +3,9 @@ import torch.nn as nn
 from einops import rearrange, repeat
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE_, _video_vae
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
 
 
 class WanSFVAE:
@@ -46,11 +49,11 @@ class WanSFVAE:
         self.scale = [self.mean, self.inv_std]
 
     def to_cuda(self):
-        self.model.encoder = self.model.encoder.to("cuda")
-        self.model.decoder = self.model.decoder.to("cuda")
-        self.model = self.model.to("cuda")
-        self.mean = self.mean.cuda()
-        self.inv_std = self.inv_std.cuda()
+        self.model.encoder = self.model.encoder.to(AI_DEVICE)
+        self.model.decoder = self.model.decoder.to(AI_DEVICE)
+        self.model = self.model.to(AI_DEVICE)
+        self.mean = self.mean.to(AI_DEVICE)
+        self.inv_std = self.inv_std.to(AI_DEVICE)
         self.scale = [self.mean, self.inv_std]
 
     def decode(self, latent: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
...
@@ -56,86 +56,132 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.infer_dtype = torch.bfloat16  # bias dtype
+        self.bias_force_fp32 = False
 
     # =========================
     # weight load functions
     # =========================
-    def load_from_disk(self):  # Need Rewrite
-        if not torch._dynamo.is_compiling():
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
-            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
-            if self.bias_name is not None:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
-        else:
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name)
-            self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
-            if self.bias_name is not None:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
-        if self.weight_need_transpose:
-            self.weight = self.weight.t()
-
     def load(self, weight_dict):
-        if not self.lazy_load:
-            self.load_func(weight_dict)
+        self.load_quantized(weight_dict)
         if self.weight_need_transpose:
-            if hasattr(self, "weight"):
+            if hasattr(self, "weight") and self.weight is not None:
                 self.weight = self.weight.t()
-            if hasattr(self, "pin_weight"):
+            if hasattr(self, "pin_weight") and self.pin_weight is not None:
                 self.pin_weight = self.pin_weight.t()
-            if hasattr(self, "weight_cuda_buffer"):
+            if hasattr(self, "weight_cuda_buffer") and self.weight_cuda_buffer is not None:
                 self.weight_cuda_buffer = self.weight_cuda_buffer.t()
 
+    def clear(self):
+        attrs = ["weight", "weight_scale", "bias", "pin_weight", "pin_weight_scale", "pin_bias"]
+        for attr in attrs:
+            if hasattr(self, attr):
+                delattr(self, attr)
+            setattr(self, attr, None)
+
+    def _calculate_size(self):
+        if self.bias is not None:
+            return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()
+        return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()
+
     def load_quantized(self, weight_dict):
         if self.create_cuda_buffer:
-            # move to cuda buffer
-            self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
-            self.weight_scale_cuda_buffer = weight_dict[self.weight_scale_name].float().cuda()
+            self._load_cuda_buffers(weight_dict)
+        elif self.create_cpu_buffer:
+            self._load_cpu_pin_buffers()
         else:
-            device = weight_dict[self.weight_name].device
-            if device.type == "cpu":
-                weight_shape = weight_dict[self.weight_name].shape
-                weight_dtype = weight_dict[self.weight_name].dtype
-                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-                self.pin_weight.copy_(weight_dict[self.weight_name])
-                weight_scale_shape = weight_dict[self.weight_scale_name].shape
-                weight_scale_dtype = torch.float
-                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
-                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
-                del weight_dict[self.weight_name]
-            else:
-                self.weight = weight_dict[self.weight_name]
-                self.weight_scale = weight_dict[self.weight_scale_name].float()
-
-        if self.bias_name is not None:
-            if self.create_cuda_buffer:
-                # move to cuda buffer
-                self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
-            else:
-                device = weight_dict[self.bias_name].device
-                if device.type == "cpu":
-                    bias_shape = weight_dict[self.bias_name].shape
-                    bias_dtype = weight_dict[self.bias_name].dtype
-                    self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                    self.pin_bias.copy_(weight_dict[self.bias_name])
-                else:
-                    self.bias = weight_dict[self.bias_name]
-        else:
-            self.bias = None
-            self.pin_bias = None
+            self._load_default_tensors(weight_dict)
+
+    def _load_cuda_buffers(self, weight_dict):
+        source = self.lazy_load_file if self.lazy_load else weight_dict
+        self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
+        self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
+
+    def _get_cuda_tensor_pair(self, source, is_lazy):
+        if is_lazy:
+            weight = source.get_tensor(self.weight_name).to(AI_DEVICE)
+            scale = source.get_tensor(self.weight_scale_name).float().to(AI_DEVICE)
+        else:
+            weight = source[self.weight_name].to(AI_DEVICE)
+            scale = source[self.weight_scale_name].float().to(AI_DEVICE)
+        return weight, scale
+
+    def _get_cuda_bias_tensor(self, source, is_lazy):
+        if self.bias_name is None:
+            return None
+        if is_lazy:
+            bias = source.get_tensor(self.bias_name)
+            dtype = self.infer_dtype
+        else:
+            bias = source[self.bias_name]
+            dtype = bias.dtype
+        if self.bias_force_fp32:
+            bias = bias.to(torch.float32)
+        else:
+            bias = bias.to(dtype)
+        return bias.to(AI_DEVICE)
+
+    def _load_cpu_pin_buffers(self):
+        self.pin_weight, self.pin_weight_scale = self._get_cpu_pin_tensor_pair(self.lazy_load_file, is_lazy=True)
+        self.pin_bias = self._get_cpu_pin_bias_tensor(self.lazy_load_file, is_lazy=True)
+        self.bias = None
+
+    def _get_cpu_pin_tensor_pair(self, source, is_lazy):
+        if is_lazy:
+            weight_tensor = source.get_tensor(self.weight_name)
+            scale_tensor = source.get_tensor(self.weight_scale_name)
+            scale_dtype = torch.float
+        else:
+            weight_tensor = source[self.weight_name]
+            scale_tensor = source[self.weight_scale_name]
+            scale_dtype = torch.float
+        pin_weight = self._create_pin_tensor(weight_tensor)
+        pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
+        return pin_weight, pin_scale
+
+    def _get_cpu_pin_bias_tensor(self, source, is_lazy):
+        if self.bias_name is None:
+            return None
+        if is_lazy:
+            bias_tensor = source.get_tensor(self.bias_name)
+            if not self.bias_force_fp32:
+                bias_tensor = bias_tensor.to(self.infer_dtype)
+        else:
+            bias_tensor = source[self.bias_name]
+            if self.bias_force_fp32:
+                bias_tensor = bias_tensor.to(torch.float32)
+        return self._create_pin_tensor(bias_tensor)
+
+    def _create_pin_tensor(self, tensor, dtype=None):
+        dtype = dtype or tensor.dtype
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
+        pin_tensor.copy_(tensor)
+        del tensor
+        return pin_tensor
+
+    def _load_default_tensors(self, weight_dict):
+        if not self.lazy_load:
+            self.weight, self.weight_scale, self.pin_weight, self.pin_weight_scale = self._get_device_tensor_pair(weight_dict)
+            self._load_default_bias(weight_dict)
+        else:
+            self.bias = None
+            self.pin_bias = None
+
+    def _get_device_tensor_pair(self, source):
+        device = source[self.weight_name].device
+        if device.type == "cpu":
+            pin_weight, pin_scale = self._get_cpu_pin_tensor_pair(source, is_lazy=False)
+            return None, None, pin_weight, pin_scale
+        else:
+            return source[self.weight_name], source[self.weight_scale_name].float(), None, None
+
+    def _load_default_bias(self, source):
+        if self.bias_name is None:
+            self.bias = None
+            self.pin_bias = None
+            self.bias_cuda_buffer = None
+            return
+        if self.create_cuda_buffer:
+            self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, is_lazy=False)
+            self.bias = None
+            self.pin_bias = None
+        else:
+            bias_tensor = source[self.bias_name].float() if self.bias_force_fp32 else source[self.bias_name]
+            device = bias_tensor.device
+            if device.type == "cpu":
+                self.pin_bias = self._get_cpu_pin_bias_tensor(source, is_lazy=False)
+                self.bias = None
+            else:
+                self.bias = bias_tensor
+                self.pin_bias = None
 
     def load_fp8_perchannel_sym(self, weight_dict):
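For context, the new `_create_pin_tensor` helper centralizes the pinned-memory staging that the old `load_quantized` inlined per tensor. A minimal sketch of why pinning matters (the `stage_to_device` helper is hypothetical, not part of this diff):

```python
import torch

def stage_to_device(t: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    # Page-locked (pinned) host memory can be DMA-transferred, so the copy
    # below may run asynchronously and overlap with compute on the device.
    pinned = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    pinned.copy_(t)
    return pinned.to(device, non_blocking=True)
```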
@@ -161,7 +207,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load_mxfp4(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
             device = weight_dict[self.weight_name].device
-            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
+            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
             self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
@@ -184,7 +230,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load_mxfp6(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
             device = weight_dict[self.weight_name].device
-            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
+            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
             self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
@@ -207,7 +253,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load_mxfp8(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
             device = weight_dict[self.weight_name].device
-            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
+            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
             self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
@@ -265,19 +311,16 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if self.bias_name is not None:
             if self.create_cuda_buffer:
-                # move to cuda buffer
-                self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
+                self.bias_cuda_buffer = weight_dict[self.bias_name].to(AI_DEVICE)
             else:
                 device = weight_dict[self.bias_name].device
-                if device.type == "cuda":
-                    self.bias = weight_dict[self.bias_name]
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     bias_shape = weight_dict[self.bias_name].shape
                     bias_dtype = weight_dict[self.bias_name].dtype
                     self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                     self.pin_bias.copy_(weight_dict[self.bias_name])
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.bias = weight_dict[self.bias_name]
         else:
             self.bias = None
             self.pin_bias = None
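The reordered branch above is part of the platform fix: any non-CPU device now keeps the bias in place rather than raising, where previously only `"cuda"` was accepted. A hypothetical illustration of the new dispatch:

```python
import torch

def place_bias(bias: torch.Tensor) -> torch.Tensor:
    # New logic: pin on CPU, otherwise keep the tensor wherever it lives
    # ("cuda", "npu", ...), instead of rejecting everything but CUDA.
    if bias.device.type == "cpu":
        staged = torch.empty(bias.shape, dtype=bias.dtype, pin_memory=True)
        staged.copy_(bias)
        return staged
    return bias
```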
@@ -388,3 +431,33 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
         else:
             self.bias = None
+
+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+            self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1)
+        else:
+            self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+            self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
+
+        if self.weight_need_transpose:
+            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
+        else:
+            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
+        self.pin_weight = self.pin_weight.copy_(weight_tensor)
+
+        weight_scale_tensor = self.lazy_load_file.get_tensor(self.weight_scale_name)
+        self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
+        del weight_tensor
+
+        if self.bias_name is not None:
+            if self.is_post_adapter:
+                assert adapter_block_index is not None
+                self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+            else:
+                self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
+            self.pin_bias.copy_(bias_tensor)
+            del bias_tensor
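The name rewriting in `load_state_dict_from_disk` relies on `re.sub(..., count=1)` touching only the first numeric segment of a weight name. A quick illustration with a hypothetical key:

```python
import re

name = "blocks.0.attn.qkv.weight"  # hypothetical key layout
block_index = 7
print(re.sub(r"\.\d+", lambda m: f".{block_index}", name, count=1))
# blocks.7.attn.qkv.weight  (later ".N" segments are left untouched)
```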