Commit e74270f5 authored by Xinchi Huang, committed by GitHub

fix offload & peak memory decorator



* fix offload

* fix offload

---------
Co-authored-by: de1star <843414674@qq.com>
parent c8606815
@@ -56,8 +56,8 @@ class MMWeight(MMWeightTemplate):
         super().__init__(weight_name, bias_name)

     def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].t().cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.weight = weight_dict[self.weight_name].t()
+        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
@@ -106,38 +106,38 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.weight = self.weight.t()

     def load_quantized(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].cuda()
-        self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].float().cuda()
+        self.weight = weight_dict[self.weight_name]
+        self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + ".weight_scale"].float()

     def load_fp8_perchannel_sym(self, weight_dict):
         if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
-            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
+            self.weight = weight_dict[self.weight_name].to(torch.float32)
             w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.float8_e4m3fn)
             self.weight_scale = self.weight_scale.to(torch.float32)
         else:
             self.load_quantized(weight_dict)
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

     def load_int8_perchannel_sym(self, weight_dict):
         if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
-            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
+            self.weight = weight_dict[self.weight_name].to(torch.float32)
             w_quantizer = IntegerQuantizer(8, True, "per_channel")
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.int8)
             self.weight_scale = self.weight_scale.to(torch.float32)
         else:
             self.load_quantized(weight_dict)
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

     def load_fp8_perblock128_sym(self, weight_dict):
         if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False):
-            self.weight = weight_dict[self.weight_name].cuda()
+            self.weight = weight_dict[self.weight_name]
             self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
         else:
             self.load_quantized(weight_dict)
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

     def per_block_cast_to_fp8(self, x):
         assert x.dim() == 2
......
@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .tokenizer import HuggingfaceTokenizer
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger

 __all__ = [
@@ -460,6 +461,7 @@ def umt5_xxl(**kwargs):


 class T5EncoderModel:
+    @peak_memory_decorator
     def __init__(
         self,
         text_len,
......
@@ -10,6 +10,7 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from lightx2v.models.input_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger

 from .xlm_roberta import XLMRoberta
@@ -428,6 +429,7 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r


 class CLIPModel:
+    @peak_memory_decorator
     def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
         self.dtype = dtype
         self.device = device
......
@@ -3,9 +3,11 @@ import torch
 from safetensors import safe_open
 from loguru import logger
 import gc
+from lightx2v.utils.memory_profiler import peak_memory_decorator


 class WanLoraWrapper:
+    @peak_memory_decorator
     def __init__(self, wan_model):
         self.model = wan_model
         self.lora_metadata = {}
......
@@ -20,6 +20,7 @@ from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
@@ -28,6 +29,7 @@ class WanModel:
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

+    @peak_memory_decorator
     def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
@@ -52,6 +54,8 @@ class WanModel:
         if self.config["cpu_offload"]:
             self.to_cpu()
+        else:
+            self.to_cuda()

     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
......
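
Note on the offload fix: the weight load methods above no longer call .cuda(), so checkpoint tensors stay on the CPU after loading, and WanModel.__init__ now moves the model to the GPU only when cpu_offload is disabled. The sketch below illustrates this load-on-CPU / move-explicitly pattern under those assumptions; the to_cpu/to_cuda bodies here are illustrative and are not the repository's actual implementation.

# Sketch only: illustrative, not the repository's implementation.
import torch


class OffloadableModel:
    def __init__(self, config):
        self.config = config
        # Tensors stay wherever the checkpoint loader put them (normally CPU),
        # mirroring the removal of .cuda() from the load* methods above.
        self.weights = {"proj.weight": torch.randn(8, 8)}
        if self.config.get("cpu_offload", False):
            self.to_cpu()   # keep weights off the GPU until a block actually runs
        else:
            self.to_cuda()  # otherwise transfer everything to the GPU once, up front

    def to_cuda(self):
        self.weights = {k: v.cuda() for k, v in self.weights.items()}

    def to_cpu(self):
        self.weights = {k: v.cpu() for k, v in self.weights.items()}


if __name__ == "__main__":
    OffloadableModel({"cpu_offload": True})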
@@ -5,6 +5,7 @@ from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.utils.utils import save_videos_grid, cache_video
 from lightx2v.utils.prompt_enhancer import PromptEnhancer
 from lightx2v.utils.envs import *
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
@@ -45,6 +46,7 @@ class DefaultRunner:
         gc.collect()
         torch.cuda.empty_cache()

+    @peak_memory_decorator
     def run(self):
         for step_index in range(self.model.scheduler.infer_steps):
             logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
@@ -74,6 +76,7 @@ class DefaultRunner:
         torch.cuda.empty_cache()

     @ProfilingContext("Run VAE")
+    @peak_memory_decorator
     def run_vae(self, latents, generator):
         images = self.vae_model.decode(latents, generator=generator, config=self.config)
         return images
......
@@ -16,6 +16,7 @@ from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 import torch.distributed as dist
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
@@ -79,6 +80,7 @@ class WanRunner(DefaultRunner):
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)

+    @peak_memory_decorator
     def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
         text_encoder_output = {}
         n_prompt = config.get("negative_prompt", "")
@@ -88,6 +90,7 @@ class WanRunner(DefaultRunner):
         text_encoder_output["context_null"] = context_null
         return text_encoder_output

+    @peak_memory_decorator
     def run_image_encoder(self, config, image_encoder, vae_model):
         img = Image.open(config.image_path).convert("RGB")
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
......
@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 from einops import rearrange
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger

 __all__ = [
@@ -651,6 +652,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):


 class WanVAE:
+    @peak_memory_decorator
     def __init__(
         self,
         z_dim=16,
......
import torch
from loguru import logger


def peak_memory_decorator(func):
    def wrapper(*args, **kwargs):
        # Check whether we are running in a distributed environment
        rank_info = ""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            rank_info = f"Rank {rank} - "

        # If a GPU is in use, reset the peak memory statistics
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        # Run the wrapped function
        result = func(*args, **kwargs)

        # Report the peak GPU memory allocated during the call
        if torch.cuda.is_available():
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # convert to GB
            logger.info(f"{rank_info}Function '{func.__qualname__}' Peak Memory: {peak_memory:.2f} GB")
        else:
            logger.info(f"{rank_info}Function '{func.__qualname__}' executed without GPU.")
        return result

    return wrapper
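
For reference, a minimal usage sketch of the decorator defined above. The TinyEncoder class and its encode method are hypothetical; in this commit the decorator is applied to __init__, run, run_vae, run_text_encoder, and run_image_encoder.

import torch

from lightx2v.utils.memory_profiler import peak_memory_decorator


class TinyEncoder:
    # Hypothetical example class; only the decorator usage mirrors the commit.
    @peak_memory_decorator
    def encode(self, x):
        return x * 2


if __name__ == "__main__":
    out = TinyEncoder().encode(torch.ones(4))
    # With a GPU available the decorator logs, e.g.:
    #   Function 'TinyEncoder.encode' Peak Memory: 0.00 GB
    # Without a GPU it logs that the function executed without GPU.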