Commit 7a111e37 authored by gushiqiao, committed by GitHub

[Fix] Fix moe offload bug (#330)

parent c6be06a6
@@ -12,7 +12,7 @@
     "sample_shift": 16,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
@@ -12,7 +12,7 @@
     "sample_shift": 16,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
@@ -12,7 +12,7 @@
     "sample_shift": 16,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
    "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
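The same one-line switch is applied in the three configs above. For context, `"model"` granularity swaps the whole transformer between host and GPU as a single unit, while `"block"` streams one transformer block at a time, keeping peak VRAM near a single block's footprint at the cost of more transfers. A minimal sketch of the block-granularity loop, with hypothetical `blocks`, `to_cuda`, and `to_cpu` helpers rather than the actual lightx2v scheduler:

```python
import torch

def run_blocks_with_offload(blocks, x):
    # Hypothetical block-granularity offload: only the active block's
    # weights reside on the GPU; everything else stays in host memory.
    for block in blocks:
        block.to_cuda(non_blocking=True)   # H2D upload from (pinned) host buffers
        x = block(x)                       # compute; a real scheduler overlaps the next copy
        block.to_cpu(non_blocking=True)    # D2H offload frees VRAM for the next block
    torch.cuda.synchronize()               # simple sketch: settle all async copies at the end
    return x
```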
@@ -10,9 +10,11 @@
     "seed": 42,
     "sample_guide_scale": [3.5, 3.5],
     "sample_shift": 5.0,
-    "enable_cfg": true,
+    "enable_cfg": false,
     "cpu_offload": true,
     "offload_granularity": "phase",
     "boundary": 0.900,
-    "use_image_encoder": false
+    "use_image_encoder": false,
+    "boundary_step_index": 2,
+    "denoising_step_list": [1000, 750, 500, 250]
 }
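The two new keys make the high/low-noise expert switch step-indexed rather than purely sigma-based: with a four-entry `denoising_step_list` and `boundary_step_index` of 2, the first two steps would run one expert and the last two the other. A minimal sketch of that selection, using hypothetical `high_noise_expert` / `low_noise_expert` names, not the actual lightx2v runner:

```python
denoising_step_list = [1000, 750, 500, 250]
boundary_step_index = 2

def pick_expert(step_index):
    # Hypothetical: early (high-noise) steps go to one expert,
    # later (low-noise) steps to the other.
    return "high_noise_expert" if step_index < boundary_step_index else "low_noise_expert"

for i, t in enumerate(denoising_step_list):
    print(f"step {i} (t={t}) -> {pick_expert(i)}")
# step 0 (t=1000) -> high_noise_expert
# step 1 (t=750)  -> high_noise_expert
# step 2 (t=500)  -> low_noise_expert
# step 3 (t=250)  -> low_noise_expert
```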
@@ -2,8 +2,8 @@
     "infer_steps": 4,
     "target_video_length": 81,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 720,
+    "target_width": 1280,
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
@@ -12,17 +12,10 @@
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250],
-    "lora_configs": [
-        {
-            "name": "low_noise_model",
-            "path": "Wan2.1-I2V-14B-480P/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
-            "strength": 1.0
-        }
-    ]
+    "denoising_step_list": [1000, 750, 500, 250]
 }
@@ -2,8 +2,8 @@
     "infer_steps": 4,
     "target_video_length": 81,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 720,
+    "target_width": 1280,
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
@@ -12,7 +12,7 @@
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
 from abc import ABCMeta, abstractmethod
 import torch
-import torch.distributed as dist
 from loguru import logger
 from lightx2v.utils.envs import *
@@ -64,18 +63,25 @@ class MMWeightTemplate(metaclass=ABCMeta):
         self.config = config

     def to_cuda(self, non_blocking=False):
-        self.weight = self.weight.cuda(non_blocking=non_blocking)
-        if hasattr(self, "weight_scale"):
-            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
-        if hasattr(self, "bias") and self.bias is not None:
-            self.bias = self.bias.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        if hasattr(self, "pin_weight_scale"):
+            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
+        if hasattr(self, "pin_bias") and self.pin_bias is not None:
+            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)

     def to_cpu(self, non_blocking=False):
-        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
-        if hasattr(self, "weight_scale"):
-            self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
-        if hasattr(self, "bias") and self.bias is not None:
-            self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+        if hasattr(self, "pin_weight"):
+            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+            if hasattr(self, "weight_scale_name"):
+                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
+            if self.bias is not None:
+                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+        else:
+            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+            if hasattr(self, "weight_scale"):
+                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
+            if hasattr(self, "bias") and self.bias is not None:
+                self.bias = self.bias.to("cpu", non_blocking=non_blocking)

 @MM_WEIGHT_REGISTER("Default")
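The rewritten `to_cpu` is the core of the fix: instead of `.to("cpu")`, which allocates a fresh pageable tensor on every offload, it copies the GPU tensor back into the pre-allocated pinned buffer with `Tensor.copy_`, so each round-trip reuses the same page-locked memory and the copy can actually run asynchronously. A self-contained sketch of that round-trip pattern (standalone names, not the lightx2v classes):

```python
import torch

# One-time setup: a page-locked host buffer, reused for every offload round-trip.
src = torch.randn(4096, 4096)
pin_weight = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
pin_weight.copy_(src)

# Upload: pinned -> GPU is the case where non_blocking=True is truly async.
weight = pin_weight.cuda(non_blocking=True)

weight.mul_(2.0)  # ... compute on the GPU ...

# Offload: copy the GPU result back into the same pinned buffer; no new allocation.
weight = pin_weight.copy_(weight, non_blocking=True)
torch.cuda.synchronize()  # make sure the D2H copy has landed before touching it on CPU
```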
@@ -92,16 +98,16 @@ class MMWeight(MMWeightTemplate):
         elif device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].t().shape
             weight_dtype = weight_dict[self.weight_name].dtype
-            self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-            self.weight.copy_(weight_dict[self.weight_name].t())
+            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+            self.pin_weight.copy_(weight_dict[self.weight_name].t())
             if self.bias_name is not None:
                 bias_shape = weight_dict[self.bias_name].shape
                 bias_dtype = weight_dict[self.bias_name].dtype
-                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                self.bias.copy_(weight_dict[self.bias_name])
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
-                self.bias = None
+                self.pin_bias = None
             del weight_dict[self.weight_name]
         else:
             raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
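Loading mirrors the naming convention: CPU-resident weights are materialized straight into pinned `pin_*` buffers, and the source entry is deleted from `weight_dict` so the pageable original can be garbage-collected. A small sketch of that load step over a hypothetical `weight_dict`:

```python
import torch

weight_dict = {"blocks.0.attn.qkv.weight": torch.randn(3072, 1024)}
name = "blocks.0.attn.qkv.weight"

# Allocate pinned storage in the layout the kernel wants (transposed here)...
src = weight_dict[name].t()
pin_weight = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
pin_weight.copy_(src)  # ...and fill it once at load time.

# Drop the pageable original so only the pinned copy keeps the data alive.
del weight_dict[name]
```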
@@ -176,10 +182,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if not self.lazy_load:
             self.load_func(weight_dict)
             if self.weight_need_transpose:
-                self.weight = self.weight.t()
+                if hasattr(self, "weight"):
+                    self.weight = self.weight.t()
+                elif hasattr(self, "pin_weight"):
+                    self.pin_weight = self.pin_weight.t()

     def clear(self):
-        attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
+        attrs = ["weight", "weight_scale", "bias", "pin_weight", "pin_weight_scale", "pin_bias"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
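This rename in `clear()` matters: with the buffers now created as `pin_weight` and friends, the stale `pinned_*` entries would never match, so the pinned host buffers would survive `clear()` and leak page-locked memory. A tiny standalone demonstration of the mismatch:

```python
class W:  # stand-in for a weight holder after this commit
    pass

w = W()
w.pin_weight = object()

# Old list: "pinned_weight" never matches, so the pinned buffer survives clear().
for attr in ["pinned_weight", "pinned_weight_scale", "pinned_bias"]:
    if hasattr(w, attr):
        delattr(w, attr)
print(hasattr(w, "pin_weight"))  # True -> leaked

# Fixed list actually releases it.
for attr in ["pin_weight", "pin_weight_scale", "pin_bias"]:
    if hasattr(w, attr):
        delattr(w, attr)
print(hasattr(w, "pin_weight"))  # False
```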
@@ -198,15 +207,14 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         elif device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
-            self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-            self.weight.copy_(weight_dict[self.weight_name])
+            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+            self.pin_weight.copy_(weight_dict[self.weight_name])
             weight_scale_shape = weight_dict[self.weight_scale_name].shape
             weight_scale_dtype = torch.float
-            self.weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
-            self.weight_scale.copy_(weight_dict[self.weight_scale_name])
-            if dist.is_initialized():
-                del weight_dict[self.weight_name]
+            self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
+            self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
+            del weight_dict[self.weight_name]
         else:
             raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
@@ -227,12 +235,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             elif device.type == "cpu":
                 bias_shape = weight_dict[self.bias_name].shape
                 bias_dtype = weight_dict[self.bias_name].dtype
-                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                self.bias.copy_(weight_dict[self.bias_name])
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
+            self.pin_bias = None

     def load_int8_perchannel_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -251,12 +260,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             elif device.type == "cpu":
                 bias_shape = weight_dict[self.bias_name].shape
                 bias_dtype = weight_dict[self.bias_name].dtype
-                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                self.bias.copy_(weight_dict[self.bias_name])
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
+            self.pin_bias = None

     def load_fp8_perblock128_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -272,12 +282,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             elif device.type == "cpu":
                 bias_shape = weight_dict[self.bias_name].shape
                 bias_dtype = weight_dict[self.bias_name].dtype
-                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                self.bias.copy_(weight_dict[self.bias_name])
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
+            self.pin_bias = None

     def per_block_cast_to_fp8(self, x):
         assert x.dim() == 2
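The body of `per_block_cast_to_fp8` is collapsed in this view. For orientation only, here is a generic sketch of what a per-(128x128)-block symmetric fp8 cast usually looks like; this is an assumption about the technique's typical shape, not the collapsed lightx2v implementation:

```python
import torch

def per_block_cast_to_fp8_sketch(x: torch.Tensor, block: int = 128):
    # Generic per-tile symmetric quantization: each 128x128 tile gets its
    # own absmax-derived scale so fp8's small dynamic range is used locally.
    assert x.dim() == 2
    m, n = x.shape
    pm = (m + block - 1) // block * block
    pn = (n + block - 1) // block * block
    padded = torch.zeros(pm, pn, dtype=x.dtype, device=x.device)
    padded[:m, :n] = x
    tiles = padded.view(pm // block, block, pn // block, block)
    absmax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    scale = absmax / torch.finfo(torch.float8_e4m3fn).max
    q = (tiles / scale).to(torch.float8_e4m3fn)
    return q.view(pm, pn)[:m, :n], scale.view(pm // block, pn // block).float()
```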