Commit 3e215bad authored by gushiqiao

Support bf16/fp16 inference, as well as mixed-precision inference that keeps precision-sensitive layers (norms, embeddings, modulation, ...) in fp32

parent e684202c
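This change routes dtype selection through environment variables instead of hard-coded torch.bfloat16 casts. A minimal sketch of the assumed usage, based on the lightx2v/utils/envs.py helpers further down in this diff (set the variables before the first call, since the helpers are lru_cache'd):

import os

os.environ["DTYPE"] = "FP16"                  # or "BF16" (the default)
os.environ["SENSITIVE_LAYER_DTYPE"] = "FP32"  # optional: omit to run everything in DTYPE

from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE

print(GET_DTYPE())            # torch.float16
print(GET_SENSITIVE_DTYPE())  # torch.float32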
......@@ -33,7 +33,7 @@ class WanLoraWrapper:
use_bfloat16 = self.model.config.get("use_bfloat16", True)
with safe_open(file_path, framework="pt") as f:
if use_bfloat16:
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16) for key in f.keys()}
tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
else:
tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
return tensor_dict
......
import glob
import json
import os
import torch
......@@ -103,20 +101,20 @@ class WanModel:
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
with safe_open(file_path, framework="pt") as f:
return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
def _load_ckpt(self, use_bf16, skip_bf16):
def _load_ckpt(self, unified_dtype, sensitive_layer):
safetensors_path = find_hf_model_path(self.config, "dit_original_ckpt", subdir="original")
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
return weight_dict
def _load_quant_ckpt(self, use_bf16, skip_bf16):
def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
ckpt_path = self.dit_quantized_ckpt
logger.info(f"Loading quant dit model from {ckpt_path}")
......@@ -137,8 +135,8 @@ class WanModel:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
if f.get_tensor(k).dtype == torch.float:
if use_bf16 or all(s not in k for s in skip_bf16):
weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
if unified_dtype or all(s not in k for s in sensitive_layer):
weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
else:
weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
else:
......@@ -146,7 +144,7 @@ class WanModel:
return weight_dict
def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
lazy_load_model_path = self.dit_quantized_ckpt
logger.info(f"Loading splited quant model from {lazy_load_model_path}")
pre_post_weight_dict = {}
......@@ -155,8 +153,8 @@ class WanModel:
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
for k in f.keys():
if f.get_tensor(k).dtype == torch.float:
if use_bf16 or all(s not in k for s in skip_bf16):
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
if unified_dtype or all(s not in k for s in sensitive_layer):
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
else:
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
else:
......@@ -173,9 +171,9 @@ class WanModel:
pass
def _init_weights(self, weight_dict=None):
use_bf16 = GET_DTYPE() == "BF16"
unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
# Some layers run with float32 to achieve high accuracy
skip_bf16 = {
sensitive_layer = {
"norm",
"embedding",
"modulation",
......@@ -185,14 +183,12 @@ class WanModel:
}
if weight_dict is None:
if not self.dit_quantized or self.weight_auto_quant:
self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
elif self.config.get("use_gguf", False):
self.original_weight_dict = self._load_gguf_ckpt()
self.original_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
if not self.config.get("lazy_load", False):
self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
self.original_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
else:
self.original_weight_dict = self._load_quant_split_ckpt(use_bf16, skip_bf16)
self.original_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
else:
self.original_weight_dict = weight_dict
# init weights
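The rule implemented by the loaders above: when DTYPE and SENSITIVE_LAYER_DTYPE agree (unified_dtype is True), every tensor is cast to GET_DTYPE(); otherwise tensors whose key contains one of the sensitive_layer substrings keep their checkpoint dtype (typically fp32) and only the rest are cast. A standalone sketch of that per-tensor decision, using hypothetical weight keys:

import torch

sensitive_layer = {"norm", "embedding", "modulation"}   # subset of the set above

def choose_dtype(tensor, key, unified_dtype, run_dtype=torch.bfloat16):
    # Cast when a single dtype is used end-to-end, or when the key is not a
    # precision-sensitive layer; otherwise keep the checkpoint dtype (fp32).
    if unified_dtype or all(s not in key for s in sensitive_layer):
        return tensor.to(run_dtype)
    return tensor

w = torch.randn(4, dtype=torch.float32)
print(choose_dtype(w, "blocks.0.self_attn.q.weight", unified_dtype=False).dtype)  # torch.bfloat16
print(choose_dtype(w, "blocks.0.norm3.weight", unified_dtype=False).dtype)        # torch.float32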
......@@ -300,11 +296,11 @@ class WanModel:
class Wan22MoeModel(WanModel):
def _load_ckpt(self, use_bf16, skip_bf16):
def _load_ckpt(self, unified_dtype, sensitive_layer):
safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
return weight_dict
......
......@@ -12,6 +12,7 @@ from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import save_videos_grid
......@@ -62,7 +63,7 @@ class HunyuanRunner(DefaultRunner):
text_state, attention_mask = encoder.infer(text, img, self.config)
else:
text_state, attention_mask = encoder.infer(text, self.config)
text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=GET_DTYPE())
text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
return text_encoder_output
......
......@@ -20,6 +20,7 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
......@@ -259,7 +260,7 @@ class VideoGenerator:
return None
device = torch.device("cuda")
dtype = torch.bfloat16
dtype = GET_DTYPE()
vae_dtype = torch.float
tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
......@@ -312,7 +313,7 @@ class VideoGenerator:
# Prepare previous latents - ALWAYS needed, even for first segment
device = torch.device("cuda")
dtype = torch.bfloat16
dtype = GET_DTYPE()
vae_dtype = torch.float
tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
max_num_frames = self.config.target_video_length
......@@ -425,7 +426,7 @@ class WanAudioRunner(WanRunner): # type:ignore
else:
device = torch.device("cuda")
audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, weight=1.0, cpu_offload=cpu_offload)
self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload)
return self._audio_adapter_pipe
......@@ -655,13 +656,13 @@ class WanAudioRunner(WanRunner): # type:ignore
cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
# clip encoder
clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16) if self.config.get("use_image_encoder", True) else None
clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
# vae encode
cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)
if isinstance(vae_encoder_out, list):
vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(torch.bfloat16)
vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
return vae_encoder_out, clip_encoder_out
......
......@@ -8,6 +8,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext4Debug
from lightx2v.utils.registry_factory import RUNNER_REGISTER
......@@ -65,13 +66,13 @@ class WanCausVidRunner(WanRunner):
)
def run(self):
self.model.transformer_infer._init_kv_cache(dtype=torch.bfloat16, device="cuda")
self.model.transformer_infer._init_crossattn_cache(dtype=torch.bfloat16, device="cuda")
self.model.transformer_infer._init_kv_cache(dtype=GET_DTYPE(), device="cuda")
self.model.transformer_infer._init_crossattn_cache(dtype=GET_DTYPE(), device="cuda")
output_latents = torch.zeros(
(self.model.config.target_shape[0], self.num_frames + (self.num_fragments - 1) * (self.num_frames - self.num_frame_per_block), *self.model.config.target_shape[2:]),
device="cuda",
dtype=torch.bfloat16,
dtype=GET_DTYPE(),
)
start_block_idx = 0
......
......@@ -24,6 +24,7 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video
......@@ -207,7 +208,7 @@ class WanRunner(DefaultRunner):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.image_encoder
torch.cuda.empty_cache()
......@@ -271,7 +272,7 @@ class WanRunner(DefaultRunner):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(torch.bfloat16)
vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
return vae_encoder_out
def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
......
......@@ -9,6 +9,7 @@ from loguru import logger
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.registry_factory import RUNNER_REGISTER
......@@ -37,7 +38,7 @@ class WanSkyreelsV2DFRunner(WanRunner): # Diffusion forcing for SkyReelsV2 DF I
config.lat_w = lat_w
vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()], config)[0]
vae_encoder_out = vae_encoder_out.to(torch.bfloat16)
vae_encoder_out = vae_encoder_out.to(GET_DTYPE())
return vae_encoder_out
def set_target_shape(self):
......
......@@ -269,5 +269,3 @@ class CogvideoxXDPMScheduler(BaseScheduler):
x_advanced = mult[0] * self.latents - mult[1] * denoised_d + mult_noise * noise
self.latents = x_advanced
self.old_pred_original_sample = pred_original_sample
self.latents = self.latents.to(torch.bfloat16)
......@@ -5,6 +5,7 @@ import torch
from diffusers.utils.torch_utils import randn_tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.envs import *
def _to_tuple(x, dim=2):
......@@ -247,12 +248,12 @@ class HunyuanScheduler(BaseScheduler):
def prepare(self, image_encoder_output):
self.image_encoder_output = image_encoder_output
self.prepare_latents(shape=self.config.target_shape, dtype=torch.float16, image_encoder_output=image_encoder_output)
self.prepare_latents(shape=self.config.target_shape, dtype=torch.float32, image_encoder_output=image_encoder_output)
self.prepare_guidance()
self.prepare_rotary_pos_embedding(video_length=self.config.target_video_length, height=self.config.target_height, width=self.config.target_width)
def prepare_guidance(self):
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=GET_DTYPE(), device=torch.device("cuda")) * 1000.0
def step_post(self):
if self.config.task == "t2v":
......@@ -316,8 +317,8 @@ class HunyuanScheduler(BaseScheduler):
use_real=True,
theta_rescale_factor=1,
)
self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
self.freqs_cos = self.freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
self.freqs_sin = self.freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
else:
L_test = rope_sizes[0] # Latent frames
......@@ -359,5 +360,5 @@ class HunyuanScheduler(BaseScheduler):
theta_rescale_factor=1,
)
self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
self.freqs_cos = freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
self.freqs_sin = freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
import torch
from lightx2v.utils.envs import *
......@@ -15,8 +13,8 @@ class BaseScheduler:
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == "BF16":
self.latents = self.latents.to(dtype=torch.bfloat16)
if GET_DTYPE() == GET_SENSITIVE_DTYPE():
self.latents = self.latents.to(GET_DTYPE())
def clear(self):
pass
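The condition added to step_pre above only downcasts the latents when the whole run uses a single dtype; in mixed-precision mode the latents stay in the wider dtype they were prepared in. A minimal sketch of the intent, using a toy tensor:

import torch
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE

latents = torch.zeros(1, 16, 8, 8, dtype=torch.float32)   # toy latents
if GET_DTYPE() == GET_SENSITIVE_DTYPE():
    latents = latents.to(GET_DTYPE())   # unified run: e.g. bfloat16 end-to-end
# mixed run: latents remain float32 for the precision-sensitive path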
import gc
import math
import warnings
import numpy as np
import torch
from torch import Tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.envs import *
......@@ -34,8 +32,8 @@ class EulerSchedulerTimestepFix(BaseScheduler):
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == "BF16":
self.latents = self.latents.to(dtype=torch.bfloat16)
if GET_DTYPE() == GET_SENSITIVE_DTYPE():
self.latents = self.latents.to(GET_DTYPE())
def prepare(self, image_encoder_output=None):
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
......
......@@ -5,6 +5,7 @@ import numpy as np
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
class WanSkyreelsV2DFScheduler(WanScheduler):
......@@ -132,7 +133,8 @@ class WanSkyreelsV2DFScheduler(WanScheduler):
def step_pre(self, step_index):
self.step_index = step_index
self.latents = self.latents.to(dtype=torch.bfloat16)
if GET_DTYPE() == GET_SENSITIVE_DTYPE():
self.latents = self.latents.to(GET_DTYPE())
valid_interval_start, valid_interval_end = self.valid_interval[step_index]
timestep = self.step_matrix[step_index][None, valid_interval_start:valid_interval_end].clone()
......
......@@ -156,7 +156,7 @@ class UpsampleCausal3D(nn.Module):
# Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16/float16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
if dtype in [torch.bfloat16, torch.float16]:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
......@@ -185,7 +185,7 @@ class UpsampleCausal3D(nn.Module):
hidden_states = first_h
# If the input was bfloat16/float16, cast back to the original dtype
if dtype == torch.bfloat16:
if dtype in [torch.bfloat16, torch.float16]:
hidden_states = hidden_states.to(dtype)
if self.use_conv:
......
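The UpsampleCausal3D change above extends an existing workaround from bfloat16 to float16: activations are promoted to float32 around the nearest-neighbour upsample and then cast back. An illustrative sketch of that promote-then-restore pattern (not the module's actual code):

import torch
import torch.nn.functional as F

def upsample_nearest_any_dtype(x, scale_factor=2.0):
    orig_dtype = x.dtype
    if orig_dtype in (torch.bfloat16, torch.float16):
        x = x.to(torch.float32)              # op lacks half-precision support here
    x = F.interpolate(x, scale_factor=scale_factor, mode="nearest")
    return x.to(orig_dtype)                  # hand back the caller's original dtype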
......@@ -6,6 +6,7 @@ from diffusers.video_processor import VideoProcessor # type: ignore
from safetensors import safe_open # type: ignore
from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX
from lightx2v.utils.envs import *
class CogvideoxVAE:
......@@ -15,7 +16,7 @@ class CogvideoxVAE:
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self, model_path):
......@@ -39,7 +40,7 @@ class CogvideoxVAE:
self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"] # 4
self.vae_scaling_factor_image = self.vae_config["scaling_factor"] # 0.7
self.model.load_state_dict(vae_ckpt)
self.model.to(torch.bfloat16).to(torch.device("cuda"))
self.model.to(GET_DTYPE()).to(torch.device("cuda"))
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@torch.no_grad()
......
import os
from functools import lru_cache
import torch
DTYPE_MAP = {
"BF16": torch.bfloat16,
"FP16": torch.float16,
"FP32": torch.float32,
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
"torch.bfloat16": torch.bfloat16,
"torch.float16": torch.float16,
"torch.float32": torch.float32,
}
@lru_cache(maxsize=None)
def CHECK_ENABLE_PROFILING_DEBUG():
......@@ -22,5 +36,14 @@ def GET_RUNNING_FLAG():
@lru_cache(maxsize=None)
def GET_DTYPE():
RUNNING_FLAG = os.getenv("DTYPE")
return RUNNING_FLAG
RUNNING_FLAG = os.getenv("DTYPE", "BF16")
assert RUNNING_FLAG in ["BF16", "FP16"]
return DTYPE_MAP[RUNNING_FLAG]
@lru_cache(maxsize=None)
def GET_SENSITIVE_DTYPE():
RUNNING_FLAG = os.getenv("SENSITIVE_LAYER_DTYPE", None)
if RUNNING_FLAG is None:
return GET_DTYPE()
return DTYPE_MAP[RUNNING_FLAG]
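GET_DTYPE() now defaults to BF16 and only accepts BF16/FP16, while GET_SENSITIVE_DTYPE() falls back to GET_DTYPE() when SENSITIVE_LAYER_DTYPE is unset, so the two only differ in mixed-precision runs. A quick check of what the exports in the scripts below resolve to (a sketch; the helpers are cached, so set the variables before the first call):

import os
import torch

os.environ.setdefault("DTYPE", "BF16")                  # matches the scripts below
os.environ.setdefault("SENSITIVE_LAYER_DTYPE", "FP32")

from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE

assert GET_DTYPE() == torch.bfloat16
assert GET_SENSITIVE_DTYPE() == torch.float32
# Leaving SENSITIVE_LAYER_DTYPE unset would make both calls return torch.bfloat16.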
......@@ -25,7 +25,8 @@ fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export SENSITIVE_LAYER_DTYPE=FP32
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
......
......@@ -25,7 +25,8 @@ fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export SENSITIVE_LAYER_DTYPE=FP32
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
......
......@@ -25,7 +25,8 @@ fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export SENSITIVE_LAYER_DTYPE=FP32
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
......