Commit 3e215bad authored by gushiqiao

Support bf16/fp16 inference, and mixed-precision inference that keeps precision-sensitive layers in fp32

parent e684202c
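The change below routes every hard-coded `torch.bfloat16` cast through two environment-driven helpers, `GET_DTYPE()` and `GET_SENSITIVE_DTYPE()`, added in `lightx2v/utils/envs.py`. A minimal standalone sketch of the selection logic they implement (the lowercase helper names and example values here are illustrative, not the repo API):

# Standalone sketch of the new dtype selection; lowercase names mark it as an
# illustration of GET_DTYPE / GET_SENSITIVE_DTYPE, not the real helpers.
import os

import torch

DTYPE_MAP = {"BF16": torch.bfloat16, "FP16": torch.float16, "FP32": torch.float32}


def get_dtype() -> torch.dtype:
    # Unified compute dtype; defaults to bf16, fp16 is the other supported option.
    flag = os.getenv("DTYPE", "BF16")
    assert flag in ("BF16", "FP16")
    return DTYPE_MAP[flag]


def get_sensitive_dtype() -> torch.dtype:
    # Precision-sensitive layers fall back to the unified dtype unless overridden.
    flag = os.getenv("SENSITIVE_LAYER_DTYPE")
    return get_dtype() if flag is None else DTYPE_MAP[flag]


os.environ["DTYPE"] = "FP16"                  # run the model in fp16
os.environ["SENSITIVE_LAYER_DTYPE"] = "FP32"  # keep sensitive layers in fp32
print(get_dtype(), get_sensitive_dtype())     # torch.float16 torch.float32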
@@ -33,7 +33,7 @@ class WanLoraWrapper:
         use_bfloat16 = self.model.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
             if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
             else:
                 tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
         return tensor_dict
...
-import glob
-import json
 import os

 import torch
@@ -103,20 +101,20 @@ class WanModel:
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

-    def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
+    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         with safe_open(file_path, framework="pt") as f:
-            return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}

-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, "dit_original_ckpt", subdir="original")
         safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict

-    def _load_quant_ckpt(self, use_bf16, skip_bf16):
+    def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
         ckpt_path = self.dit_quantized_ckpt
         logger.info(f"Loading quant dit model from {ckpt_path}")
@@ -137,8 +135,8 @@ class WanModel:
             logger.info(f"Loading weights from {safetensor_path}")
             for k in f.keys():
                 if f.get_tensor(k).dtype == torch.float:
-                    if use_bf16 or all(s not in k for s in skip_bf16):
-                        weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                    if unified_dtype or all(s not in k for s in sensitive_layer):
+                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
                         weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                 else:
@@ -146,7 +144,7 @@ class WanModel:
         return weight_dict

-    def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
+    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
         lazy_load_model_path = self.dit_quantized_ckpt
         logger.info(f"Loading splited quant model from {lazy_load_model_path}")
         pre_post_weight_dict = {}
@@ -155,8 +153,8 @@ class WanModel:
         with safe_open(safetensor_path, framework="pt", device="cpu") as f:
             for k in f.keys():
                 if f.get_tensor(k).dtype == torch.float:
-                    if use_bf16 or all(s not in k for s in skip_bf16):
-                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                    if unified_dtype or all(s not in k for s in sensitive_layer):
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
                         pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                 else:
@@ -173,9 +171,9 @@ class WanModel:
         pass

     def _init_weights(self, weight_dict=None):
-        use_bf16 = GET_DTYPE() == "BF16"
+        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
         # Some layers run with float32 to achieve high accuracy
-        skip_bf16 = {
+        sensitive_layer = {
             "norm",
             "embedding",
             "modulation",
@@ -185,14 +183,12 @@ class WanModel:
         }
         if weight_dict is None:
             if not self.dit_quantized or self.weight_auto_quant:
-                self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
-            elif self.config.get("use_gguf", False):
-                self.original_weight_dict = self._load_gguf_ckpt()
+                self.original_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
             else:
                 if not self.config.get("lazy_load", False):
-                    self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
+                    self.original_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                 else:
-                    self.original_weight_dict = self._load_quant_split_ckpt(use_bf16, skip_bf16)
+                    self.original_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
         else:
             self.original_weight_dict = weight_dict
         # init weights
@@ -300,11 +296,11 @@ class WanModel:

 class Wan22MoeModel(WanModel):
-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict
...
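The loaders above share one rule: when `DTYPE` and `SENSITIVE_LAYER_DTYPE` resolve to the same dtype (`unified_dtype` is true), every tensor is cast to `GET_DTYPE()`; otherwise a tensor is cast only if its key does not mention a sensitive layer. A hedged sketch of that per-key decision, with hypothetical weight names and a trimmed `sensitive_layer` set:

# Sketch of the per-key cast rule used by _load_safetensor_to_dict and the
# quant loaders above; the weight names here are hypothetical.
import torch

sensitive_layer = {"norm", "embedding", "modulation"}  # subset of the set in _init_weights
unified_dtype = False          # e.g. DTYPE=BF16 with SENSITIVE_LAYER_DTYPE=FP32
compute_dtype = torch.bfloat16

weights = {
    "blocks.0.self_attn.q.weight": torch.randn(8, 8, dtype=torch.float32),
    "blocks.0.norm1.weight": torch.randn(8, dtype=torch.float32),
}

cast = {
    key: (w.to(compute_dtype) if unified_dtype or all(s not in key for s in sensitive_layer) else w)
    for key, w in weights.items()
}
for key, w in cast.items():
    print(key, w.dtype)  # attention weight -> bfloat16, norm weight stays float32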
@@ -12,6 +12,7 @@ from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
 from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
 from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
+from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import save_videos_grid
@@ -62,7 +63,7 @@ class HunyuanRunner(DefaultRunner):
                 text_state, attention_mask = encoder.infer(text, img, self.config)
             else:
                 text_state, attention_mask = encoder.infer(text, self.config)
-            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
+            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=GET_DTYPE())
             text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
         return text_encoder_output
...
@@ -20,6 +20,7 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
@@ -259,7 +260,7 @@ class VideoGenerator:
             return None
         device = torch.device("cuda")
-        dtype = torch.bfloat16
+        dtype = GET_DTYPE()
         vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
@@ -312,7 +313,7 @@ class VideoGenerator:
         # Prepare previous latents - ALWAYS needed, even for first segment
         device = torch.device("cuda")
-        dtype = torch.bfloat16
+        dtype = GET_DTYPE()
         vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
         max_num_frames = self.config.target_video_length
@@ -425,7 +426,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
            device = torch.device("cuda")
            audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, weight=1.0, cpu_offload=cpu_offload)
+        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload)
         return self._audio_adapter_pipe
@@ -655,13 +656,13 @@ class WanAudioRunner(WanRunner):  # type:ignore
         cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
         # clip encoder
-        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16) if self.config.get("use_image_encoder", True) else None
+        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None

         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
         vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)
         if isinstance(vae_encoder_out, list):
-            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(torch.bfloat16)
+            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
         return vae_encoder_out, clip_encoder_out
...
@@ -8,6 +8,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
@@ -65,13 +66,13 @@ class WanCausVidRunner(WanRunner):
         )

     def run(self):
-        self.model.transformer_infer._init_kv_cache(dtype=torch.bfloat16, device="cuda")
-        self.model.transformer_infer._init_crossattn_cache(dtype=torch.bfloat16, device="cuda")
+        self.model.transformer_infer._init_kv_cache(dtype=GET_DTYPE(), device="cuda")
+        self.model.transformer_infer._init_crossattn_cache(dtype=GET_DTYPE(), device="cuda")
         output_latents = torch.zeros(
             (self.model.config.target_shape[0], self.num_frames + (self.num_fragments - 1) * (self.num_frames - self.num_frame_per_block), *self.model.config.target_shape[2:]),
             device="cuda",
-            dtype=torch.bfloat16,
+            dtype=GET_DTYPE(),
         )
         start_block_idx = 0
...
@@ -24,6 +24,7 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
+from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
 from lightx2v.utils.utils import best_output_size, cache_video
@@ -207,7 +208,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
@@ -271,7 +272,7 @@ class WanRunner(DefaultRunner):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()
-        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(torch.bfloat16)
+        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
         return vae_encoder_out

     def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
...
@@ -9,6 +9,7 @@ from loguru import logger
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
@@ -37,7 +38,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         config.lat_w = lat_w
         vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()], config)[0]
-        vae_encoder_out = vae_encoder_out.to(torch.bfloat16)
+        vae_encoder_out = vae_encoder_out.to(GET_DTYPE())
         return vae_encoder_out

     def set_target_shape(self):
...
@@ -269,5 +269,3 @@ class CogvideoxXDPMScheduler(BaseScheduler):
             x_advanced = mult[0] * self.latents - mult[1] * denoised_d + mult_noise * noise
             self.latents = x_advanced
         self.old_pred_original_sample = pred_original_sample
-        self.latents = self.latents.to(torch.bfloat16)
@@ -5,6 +5,7 @@ import torch
 from diffusers.utils.torch_utils import randn_tensor

 from lightx2v.models.schedulers.scheduler import BaseScheduler
+from lightx2v.utils.envs import *


 def _to_tuple(x, dim=2):
@@ -247,12 +248,12 @@ class HunyuanScheduler(BaseScheduler):
     def prepare(self, image_encoder_output):
         self.image_encoder_output = image_encoder_output
-        self.prepare_latents(shape=self.config.target_shape, dtype=torch.float16, image_encoder_output=image_encoder_output)
+        self.prepare_latents(shape=self.config.target_shape, dtype=torch.float32, image_encoder_output=image_encoder_output)
         self.prepare_guidance()
         self.prepare_rotary_pos_embedding(video_length=self.config.target_video_length, height=self.config.target_height, width=self.config.target_width)

     def prepare_guidance(self):
-        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0
+        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=GET_DTYPE(), device=torch.device("cuda")) * 1000.0

     def step_post(self):
         if self.config.task == "t2v":
@@ -316,8 +317,8 @@ class HunyuanScheduler(BaseScheduler):
                 use_real=True,
                 theta_rescale_factor=1,
             )
-            self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
-            self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+            self.freqs_cos = self.freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
+            self.freqs_sin = self.freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
         else:
             L_test = rope_sizes[0]  # Latent frames
@@ -359,5 +360,5 @@ class HunyuanScheduler(BaseScheduler):
                 theta_rescale_factor=1,
             )
-            self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
-            self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+            self.freqs_cos = freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
+            self.freqs_sin = freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
-import torch

 from lightx2v.utils.envs import *
@@ -15,8 +13,8 @@ class BaseScheduler:
     def step_pre(self, step_index):
         self.step_index = step_index
-        if GET_DTYPE() == "BF16":
-            self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())

     def clear(self):
         pass
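In `step_pre` above, the unconditional bf16 cast becomes conditional: latents are moved to `GET_DTYPE()` only when the unified and sensitive dtypes coincide; under mixed precision (e.g. `SENSITIVE_LAYER_DTYPE=FP32`) they keep their original fp32 precision. A small sketch of that branch, with plain dtype variables standing in for the env helpers:

# Sketch of the conditional latent cast in step_pre; the two dtype values
# stand in for GET_DTYPE() / GET_SENSITIVE_DTYPE().
import torch

compute_dtype = torch.float16    # GET_DTYPE()
sensitive_dtype = torch.float32  # GET_SENSITIVE_DTYPE()

latents = torch.randn(1, 4, 8, 8, dtype=torch.float32)
if compute_dtype == sensitive_dtype:
    latents = latents.to(compute_dtype)
print(latents.dtype)  # float32 here: mixed precision keeps latents in fp32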
 import gc
 import math
-import warnings

 import numpy as np
 import torch
-from torch import Tensor

 from lightx2v.models.schedulers.scheduler import BaseScheduler
 from lightx2v.utils.envs import *
@@ -34,8 +32,8 @@ class EulerSchedulerTimestepFix(BaseScheduler):
     def step_pre(self, step_index):
         self.step_index = step_index
-        if GET_DTYPE() == "BF16":
-            self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())

     def prepare(self, image_encoder_output=None):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
...
@@ -5,6 +5,7 @@ import numpy as np
 import torch

 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
+from lightx2v.utils.envs import *


 class WanSkyreelsV2DFScheduler(WanScheduler):
@@ -132,7 +133,8 @@ class WanSkyreelsV2DFScheduler(WanScheduler):
     def step_pre(self, step_index):
         self.step_index = step_index
-        self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())
         valid_interval_start, valid_interval_end = self.valid_interval[step_index]
         timestep = self.step_matrix[step_index][None, valid_interval_start:valid_interval_end].clone()
...
@@ -156,7 +156,7 @@ class UpsampleCausal3D(nn.Module):
         # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
         dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
+        if dtype in [torch.bfloat16, torch.float16]:
             hidden_states = hidden_states.to(torch.float32)

         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
@@ -185,7 +185,7 @@ class UpsampleCausal3D(nn.Module):
             hidden_states = first_h

         # If the input is bfloat16, we cast back to bfloat16
-        if dtype == torch.bfloat16:
+        if dtype in [torch.bfloat16, torch.float16]:
             hidden_states = hidden_states.to(dtype)

         if self.use_conv:
...
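The VAE upsample path above already upcast bf16 activations to fp32 around the nearest-neighbour op; with fp16 inference enabled, fp16 now takes the same detour. A minimal 2D sketch of the round-trip (the real module is causal 3D; only `torch.nn.functional.interpolate` is assumed here):

# Sketch of the half-precision round-trip: upsample in fp32, cast back after.
import torch
import torch.nn.functional as F

hidden_states = torch.randn(1, 4, 8, 8, dtype=torch.float16)
dtype = hidden_states.dtype
if dtype in (torch.bfloat16, torch.float16):
    hidden_states = hidden_states.to(torch.float32)

hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")

if dtype in (torch.bfloat16, torch.float16):
    hidden_states = hidden_states.to(dtype)
print(hidden_states.shape, hidden_states.dtype)  # torch.Size([1, 4, 16, 16]) torch.float16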
@@ -6,6 +6,7 @@ from diffusers.video_processor import VideoProcessor  # type: ignore
 from safetensors import safe_open  # type: ignore

 from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX
+from lightx2v.utils.envs import *


 class CogvideoxVAE:
@@ -15,7 +16,7 @@ class CogvideoxVAE:
     def _load_safetensor_to_dict(self, file_path):
         with safe_open(file_path, framework="pt") as f:
-            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
+            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self, model_path):
@@ -39,7 +40,7 @@ class CogvideoxVAE:
         self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"]  # 4
         self.vae_scaling_factor_image = self.vae_config["scaling_factor"]  # 0.7
         self.model.load_state_dict(vae_ckpt)
-        self.model.to(torch.bfloat16).to(torch.device("cuda"))
+        self.model.to(GET_DTYPE()).to(torch.device("cuda"))
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

     @torch.no_grad()
...
 import os
 from functools import lru_cache

+import torch
+
+DTYPE_MAP = {
+    "BF16": torch.bfloat16,
+    "FP16": torch.float16,
+    "FP32": torch.float32,
+    "bf16": torch.bfloat16,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "torch.bfloat16": torch.bfloat16,
+    "torch.float16": torch.float16,
+    "torch.float32": torch.float32,
+}

 @lru_cache(maxsize=None)
 def CHECK_ENABLE_PROFILING_DEBUG():
@@ -22,5 +36,14 @@ def GET_RUNNING_FLAG():

 @lru_cache(maxsize=None)
 def GET_DTYPE():
-    RUNNING_FLAG = os.getenv("DTYPE")
-    return RUNNING_FLAG
+    RUNNING_FLAG = os.getenv("DTYPE", "BF16")
+    assert RUNNING_FLAG in ["BF16", "FP16"]
+    return DTYPE_MAP[RUNNING_FLAG]
+
+
+@lru_cache(maxsize=None)
+def GET_SENSITIVE_DTYPE():
+    RUNNING_FLAG = os.getenv("SENSITIVE_LAYER_DTYPE", None)
+    if RUNNING_FLAG is None:
+        return GET_DTYPE()
+    return DTYPE_MAP[RUNNING_FLAG]
@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

+export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
...
@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

+export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
...
@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

+export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
...