Commit f2a586f1 authored by Xinchi Huang, committed by GitHub

async offload & offload ratio & context4debug (#44)



* async offload & context4debug

* offload ratio

* Merge branch 'main' into xinchi/fix_offload

* adding offload ratio

* pre-commit

---------
Co-authored-by: de1star <843414674@qq.com>
parent 8421734f
 import torch
+import torch._dynamo as dynamo
 import torch.distributed as dist

+@dynamo.disable
 def all2all_seq2head(input):
     """
     Convert the input tensor from [seq_len/N, heads, hidden_dims] format to [seq_len, heads/N, hidden_dims].
@@ -42,6 +44,7 @@ def all2all_seq2head(input):
     return output  # return the converted output tensor

+@dynamo.disable
 def all2all_head2seq(input):
     """
     Convert the input tensor from [seq_len, heads/N, hidden_dims] format to [seq_len/N, heads, hidden_dims].
...
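Note on the @dynamo.disable additions above: presumably this keeps torch.compile (TorchDynamo) from tracing into the all-to-all collectives, which generally cause graph breaks; the decorated helpers simply run eagerly. A minimal sketch of the pattern, where the exchange helper is illustrative and not from this repo:

import torch
import torch._dynamo as dynamo
import torch.distributed as dist

@dynamo.disable  # run eagerly; TorchDynamo skips tracing this function
def exchange(t: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(t)
    dist.all_to_all_single(out, t)  # requires an initialized process group
    return out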
@@ -2,39 +2,48 @@ import torch

 class WeightAsyncStreamManager(object):
-    def __init__(self):
-        self.active_weights = [None for _ in range(2)]
+    def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
+        self.active_weights = [None for _ in range(3)]
         self.compute_stream = torch.cuda.Stream(priority=-1)
-        self.load_stream = torch.cuda.Stream(priority=0)
+        self.cpu_load_stream = torch.cuda.Stream(priority=0)
+        self.cuda_load_stream = torch.cuda.Stream(priority=0)
+        self.offload_block_num = offload_ratio * blocks_num
+        self.phases_num = phases_num
+        self.offload_phases_num = blocks_num * phases_num * offload_ratio

     def prefetch_weights(self, block_idx, blocks_weights):
-        with torch.cuda.stream(self.load_stream):
-            if self.active_weights[1] is not None:
-                self.active_weights[1].to_cpu_async()
-            new_weights = blocks_weights[block_idx]
-            new_weights.to_cuda_async()
-            self.active_weights[1] = new_weights
+        with torch.cuda.stream(self.cuda_load_stream):
+            self.active_weights[2] = blocks_weights[block_idx]
+            self.active_weights[2].to_cuda_async()
+        with torch.cuda.stream(self.cpu_load_stream):
+            if block_idx < self.offload_block_num:
+                if self.active_weights[1] is not None:
+                    self.active_weights[1].to_cpu_async()

     def swap_weights(self):
         self.compute_stream.synchronize()
-        self.load_stream.synchronize()
+        self.cpu_load_stream.synchronize()
+        self.cuda_load_stream.synchronize()
         self.active_weights[0], self.active_weights[1] = (
-            self.active_weights[1],
+            self.active_weights[2],
             self.active_weights[0],
         )

     def prefetch_phase(self, block_idx, phase_idx, blocks):
-        with torch.cuda.stream(self.load_stream):
-            if self.active_weights[1] is not None:
-                _, old_phase = self.active_weights[1]
-                old_phase.to_cpu_async()
-            new_phase = blocks[block_idx].compute_phases[phase_idx]
-            new_phase.to_cuda_async()
-            self.active_weights[1] = (phase_idx, new_phase)
+        with torch.cuda.stream(self.cuda_load_stream):
+            new_phase = blocks[block_idx].compute_phases[phase_idx]
+            new_phase.to_cuda_async()
+            self.active_weights[2] = (phase_idx, blocks[block_idx].compute_phases[phase_idx])
+        with torch.cuda.stream(self.cpu_load_stream):
+            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
+                if self.active_weights[1] is not None:
+                    _, old_phase = self.active_weights[1]
+                    old_phase.to_cpu_async()

     def swap_phases(self):
         self.compute_stream.synchronize()
-        self.load_stream.synchronize()
-        self.active_weights[0], self.active_weights[1] = self.active_weights[1], self.active_weights[0]
+        self.cpu_load_stream.synchronize()
+        self.cuda_load_stream.synchronize()
+        self.active_weights[0], self.active_weights[1] = self.active_weights[2], self.active_weights[0]
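For orientation, a minimal sketch of how the three-slot manager is meant to be driven, mirroring the _infer_with_offload loop further down: slot 0 holds the block being computed, slot 2 receives the prefetched next block on cuda_load_stream, and slot 1 holds the previous block, which is written back to CPU on cpu_load_stream only while block_idx < offload_block_num (roughly the first offload_ratio fraction of blocks). infer_block, weights, and x are stand-ins; block objects are assumed to expose to_cuda()/to_cuda_async()/to_cpu_async() as in the diff.

# Sketch (assumed usage, not part of the diff): overlap block i's compute with
# the H2D copy of block i+1 and the D2H write-back of block i-1.
mgr = WeightAsyncStreamManager(blocks_num=len(weights.blocks), offload_ratio=0.5)
mgr.active_weights[0] = weights.blocks[0]
mgr.active_weights[0].to_cuda()  # the first block is loaded synchronously

for block_idx in range(len(weights.blocks)):
    if block_idx < len(weights.blocks) - 1:
        # H2D of the next block on cuda_load_stream; D2H of the previous one on cpu_load_stream
        mgr.prefetch_weights(block_idx + 1, weights.blocks)
    with torch.cuda.stream(mgr.compute_stream):
        x = infer_block(mgr.active_weights[0], x)  # hypothetical per-block compute
    mgr.swap_weights()  # sync all three streams; prefetched block rotates into slot 0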
@@ -8,7 +8,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .tokenizer import HuggingfaceTokenizer
-from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
 __all__ = [
@@ -471,7 +470,6 @@ def umt5_xxl(**kwargs):
 class T5EncoderModel:
-    @peak_memory_decorator
     def __init__(
         self,
         text_len,
...
@@ -10,7 +10,6 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from lightx2v.models.input_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
-from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
 from .xlm_roberta import XLMRoberta
@@ -429,7 +428,6 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    @peak_memory_decorator
     def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
         self.dtype = dtype
         self.device = device
...
@@ -16,8 +16,12 @@ class HunyuanTransformerInfer:
         self.mlp_hidden_dim = 12288
         self.parallel_attention = None
         if self.config["cpu_offload"]:
-            self.double_weights_stream_mgr = WeightAsyncStreamManager()
-            self.single_weights_stream_mgr = WeightAsyncStreamManager()
+            if "offload_ratio" in self.config:
+                offload_ratio = self.config["offload_ratio"]
+            else:
+                offload_ratio = 1
+            self.double_weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.double_blocks_num, offload_ratio=offload_ratio)
+            self.single_weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.single_blocks_num, offload_ratio=offload_ratio)
            self.infer_func = self._infer_with_offload
         else:
             self.infer_func = self._infer_without_offload
...
@@ -10,17 +10,22 @@ class WanTransformerInfer:
         self.task = config["task"]
         self.attention_type = config.get("attention_type", "flash_attn2")
         self.blocks_num = config["num_layers"]
+        self.phases_num = 3
         self.num_heads = config["num_heads"]
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
         if self.config["cpu_offload"]:
+            if "offload_ratio" in self.config:
+                offload_ratio = self.config["offload_ratio"]
+            else:
+                offload_ratio = 1
             offload_granularity = self.config.get("offload_granularity", "block")
-            self.weights_stream_mgr = WeightAsyncStreamManager()
             if offload_granularity == "block":
                 self.infer_func = self._infer_with_offload
             elif offload_granularity == "phase":
                 self.infer_func = self._infer_with_phases_offload
+            self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.blocks_num, offload_ratio=offload_ratio, phases_num=self.phases_num)
         else:
             self.infer_func = self._infer_without_offload
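For reference, a hypothetical config fragment showing the keys read here; only cpu_offload, offload_ratio, and offload_granularity come from the diff, the values are illustrative:

config = {
    "cpu_offload": True,
    "offload_ratio": 0.6,            # fraction of blocks/phases written back to CPU after use; defaults to 1
    "offload_granularity": "phase",  # "block" (default) or "phase"
}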
@@ -46,6 +51,9 @@ class WanTransformerInfer:
                 self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
                 self.weights_stream_mgr.active_weights[0].to_cuda()
+            if block_idx < self.blocks_num - 1:
+                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                 x = self.infer_block(
                     self.weights_stream_mgr.active_weights[0],
@@ -58,8 +66,6 @@ class WanTransformerInfer:
                     context,
                 )
-            if block_idx < self.blocks_num - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
             self.weights_stream_mgr.swap_weights()
         return x
@@ -75,7 +81,7 @@ class WanTransformerInfer:
         elif embed0.dim() == 2:
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.blocks[block_idx].modulation.tensor + embed0).chunk(6, dim=1)
-        for phase_idx in range(3):
+        for phase_idx in range(self.phases_num):
             if block_idx == 0 and phase_idx == 0:
                 phase = weights.blocks[block_idx].compute_phases[phase_idx]
                 phase.to_cuda()
...
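And a sketch of how the phase-granular path might drive prefetch_phase/swap_phases. The body of _infer_with_phases_offload is not fully shown in this diff, so the loop shape below is an assumption, and compute_phase is a hypothetical stand-in for the per-phase computation:

# Sketch (assumed loop shape): prefetch phase i+1 while phase i computes.
for block_idx in range(self.blocks_num):
    for phase_idx in range(self.phases_num):
        if block_idx == 0 and phase_idx == 0:
            phase = weights.blocks[0].compute_phases[0]
            phase.to_cuda()
            self.weights_stream_mgr.active_weights[0] = (0, phase)

        # global index of the next phase, possibly rolling over to the next block
        nxt = block_idx * self.phases_num + phase_idx + 1
        next_block, next_phase = divmod(nxt, self.phases_num)
        if next_block < self.blocks_num:
            self.weights_stream_mgr.prefetch_phase(next_block, next_phase, weights.blocks)

        with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
            x = compute_phase(self.weights_stream_mgr.active_weights[0], x)  # hypothetical

        self.weights_stream_mgr.swap_phases()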
@@ -3,11 +3,9 @@ import torch
 from safetensors import safe_open
 from loguru import logger
 import gc
-from lightx2v.utils.memory_profiler import peak_memory_decorator
 class WanLoraWrapper:
-    @peak_memory_decorator
     def __init__(self, wan_model):
         self.model = wan_model
         self.lora_metadata = {}
...
@@ -20,7 +20,6 @@ from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
-from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
@@ -29,7 +28,6 @@ class WanModel:
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights
-    @peak_memory_decorator
     def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
...
@@ -5,7 +5,6 @@ from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.utils.utils import save_videos_grid, cache_video
 from lightx2v.utils.prompt_enhancer import PromptEnhancer
 from lightx2v.utils.envs import *
-from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
@@ -46,7 +45,6 @@ class DefaultRunner:
         gc.collect()
         torch.cuda.empty_cache()
-    @peak_memory_decorator
     def run(self):
         for step_index in range(self.model.scheduler.infer_steps):
             logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
@@ -76,7 +74,6 @@ class DefaultRunner:
         torch.cuda.empty_cache()
     @ProfilingContext("Run VAE")
-    @peak_memory_decorator
     def run_vae(self, latents, generator):
         images = self.vae_model.decode(latents, generator=generator, config=self.config)
         return images
...
@@ -17,7 +17,6 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
 import torch.distributed as dist
-from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
@@ -99,7 +98,6 @@ class WanRunner(DefaultRunner):
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)
-    @peak_memory_decorator
     def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
         text_encoder_output = {}
         n_prompt = config.get("negative_prompt", "")
@@ -109,7 +107,6 @@ class WanRunner(DefaultRunner):
         text_encoder_output["context_null"] = context_null
         return text_encoder_output
-    @peak_memory_decorator
     def run_image_encoder(self, config, image_encoder, vae_model):
         if self.config.get("tiny_vae", False):
             clip_image_encoder, vae_image_encoder = image_encoder[0], image_encoder[1]
...
@@ -7,7 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 from einops import rearrange
-from lightx2v.utils.memory_profiler import peak_memory_decorator
 from loguru import logger
 __all__ = [
@@ -788,7 +787,6 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
 class WanVAE:
-    @peak_memory_decorator
     def __init__(
         self,
         z_dim=16,
...
@@ -8,14 +8,25 @@ from loguru import logger

 class _ProfilingContext(ContextDecorator):
     def __init__(self, name):
         self.name = name
+        self.rank_info = ""
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            rank = torch.distributed.get_rank()
+            self.rank_info = f"Rank {rank} - "

     def __enter__(self):
         torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats()
         self.start_time = time.perf_counter()
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # convert to GB
+            logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
+        else:
+            logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
         elapsed = time.perf_counter() - self.start_time
         logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
         return False
...
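Finally, a short usage sketch for the extended profiling context, which now folds in the peak-memory reporting previously provided by the removed @peak_memory_decorator. The decorator form matches @ProfilingContext("Run VAE") in the runner above; the context-manager form is an assumption based on _ProfilingContext subclassing ContextDecorator, and ProfilingContext/ProfilingContext4Debug (imported in the runner) are presumed thin wrappers over it:

from lightx2v.utils.profiler import ProfilingContext

@ProfilingContext("Run VAE")  # logs elapsed time and, on GPU, peak memory with a rank prefix
def run_vae(latents, generator):
    ...

with ProfilingContext("denoise step"):  # block name is illustrative
    ...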