Commit 8ba6e3b4 authored by gushiqiao's avatar gushiqiao
Browse files

Fixed the accuracy fluctuation bug

parent 793ec1db
......@@ -37,6 +37,7 @@ class WanModel:
self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
if self.dit_quantized:
assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
assert GET_DTYPE() == "BF16"
self.device = device
self._init_infer_class()
......@@ -63,13 +64,10 @@ class WanModel:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _load_safetensor_to_dict(self, file_path):
use_bfloat16 = self.config.get("use_bfloat16", True)
with safe_open(file_path, framework="pt") as f:
if use_bfloat16:
tensor_dict = {key: f.get_tensor(key).pin_memory().to(torch.bfloat16).to(self.device) for key in f.keys()}
else:
tensor_dict = {key: f.get_tensor(key).pin_memory().to(self.device) for key in f.keys()}
return tensor_dict
use_bf16 = GET_DTYPE() == "BF16"
skip_bf16 = {"norm", "embedding", "modulation", "time"}
return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
def _load_ckpt(self):
safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
......@@ -119,7 +117,7 @@ class WanModel:
pre_post_weight_dict, transformer_weight_dict = {}, {}
safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
for k in f.keys():
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory()
if pre_post_weight_dict[k].dtype == torch.float:
......@@ -154,7 +152,6 @@ class WanModel:
) = self._load_quant_split_ckpt()
else:
self.original_weight_dict = weight_dict
# init weights
self.pre_weight = self.pre_weight_class(self.config)
self.post_weight = self.post_weight_class(self.config)
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, TENSOR_REGISTER
from lightx2v.utils.registry_factory import (
MM_WEIGHT_REGISTER,
TENSOR_REGISTER,
LN_WEIGHT_REGISTER,
)
from lightx2v.common.modules.weight_module import WeightModule
......@@ -6,5 +10,9 @@ class WanPostWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.register_parameter(
"norm",
LN_WEIGHT_REGISTER["Default"](),
)
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
from lightx2v.utils.registry_factory import (
MM_WEIGHT_REGISTER,
LN_WEIGHT_REGISTER,
CONV3D_WEIGHT_REGISTER,
)
from lightx2v.common.modules.weight_module import WeightModule
......@@ -10,15 +14,45 @@ class WanPreWeights(WeightModule):
self.patch_size = (1, 2, 2)
self.config = config
self.add_module("patch_embedding", CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size))
self.add_module("text_embedding_0", MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"))
self.add_module("text_embedding_2", MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"))
self.add_module("time_embedding_0", MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"))
self.add_module("time_embedding_2", MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"))
self.add_module("time_projection_1", MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"))
self.add_module(
"patch_embedding",
CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size),
)
self.add_module(
"text_embedding_0",
MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"),
)
self.add_module(
"text_embedding_2",
MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"),
)
self.add_module(
"time_embedding_0",
MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"),
)
self.add_module(
"time_embedding_2",
MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"),
)
self.add_module(
"time_projection_1",
MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
)
if config.task == "i2v":
self.add_module("proj_0", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5))
self.add_module("proj_1", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"))
self.add_module("proj_3", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"))
self.add_module("proj_4", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5))
self.add_module(
"proj_0",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
)
self.add_module(
"proj_1",
MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"),
)
self.add_module(
"proj_3",
MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"),
)
self.add_module(
"proj_4",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias"),
)
......@@ -49,9 +49,30 @@ class WanTransformerAttentionBlock(WeightModule):
self.compute_phases = WeightModuleList(
[
WanSelfAttention(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
WanCrossAttention(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
WanFFN(block_index, task, mm_type, config, self.lazy_load, self.lazy_load_file),
WanSelfAttention(
block_index,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanCrossAttention(
block_index,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanFFN(
block_index,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
]
)
......@@ -71,6 +92,11 @@ class WanSelfAttention(WeightModule):
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.register_parameter(
"norm1",
LN_WEIGHT_REGISTER["Default"](),
)
self.add_module(
"self_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
......@@ -169,7 +195,6 @@ class WanCrossAttention(WeightModule):
f"blocks.{self.block_index}.norm3.bias",
self.lazy_load,
self.lazy_load_file,
eps=1e-6,
),
)
self.add_module(
......@@ -267,6 +292,11 @@ class WanFFN(WeightModule):
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.register_parameter(
"norm2",
LN_WEIGHT_REGISTER["Default"](),
)
self.add_module(
"ffn_0",
MM_WEIGHT_REGISTER[self.mm_type](
......
......@@ -102,7 +102,12 @@ class WanRunner(DefaultRunner):
return vae_decoder
def load_vae(self):
return self.load_vae_encoder(), self.load_vae_decoder()
vae_encoder = self.load_vae_encoder()
if vae_encoder is None or self.config.get("tiny_vae", False):
vae_decoder = self.load_vae_decoder()
else:
vae_decoder = vae_encoder
return vae_encoder, vae_decoder
def init_scheduler(self):
if self.config.feature_caching == "NoCaching":
......
import torch
from lightx2v.utils.envs import *
class BaseScheduler:
......@@ -10,7 +11,8 @@ class BaseScheduler:
def step_pre(self, step_index):
self.step_index = step_index
self.latents = self.latents.to(dtype=torch.bfloat16)
if GET_DTYPE() == "BF16":
self.latents = self.latents.to(dtype=torch.bfloat16)
def clear(self):
pass
import os
import torch
from functools import lru_cache
......@@ -18,3 +19,9 @@ def CHECK_ENABLE_GRAPH_MODE():
def GET_RUNNING_FLAG():
RUNNING_FLAG = os.getenv("RUNNING_FLAG", "infer")
return RUNNING_FLAG
@lru_cache(maxsize=None)
def GET_DTYPE():
RUNNING_FLAG = os.getenv("DTYPE")
return RUNNING_FLAG
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -50,7 +50,7 @@ if __name__ == "__main__":
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4", # It is best to set it to an absolute path.
}
......
......@@ -7,7 +7,7 @@ url = "http://localhost:8000/v1/local/video/generate"
message = {
"task_id": "test_task_001",
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_enhanced.mp4", # It is best to set it to an absolute path.
"use_prompt_enhancer": True,
......
......@@ -50,7 +50,7 @@ if __name__ == "__main__":
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "./assets/inputs/imgs/img_0.jpg",
"save_video_path": "./output_lightx2v_wan_i2v_t02.mp4", # It is best to set it to an absolute path.
}
......
......@@ -68,7 +68,7 @@ if __name__ == "__main__":
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t01.mp4", # It is best to set it to an absolute path.
},
......@@ -76,7 +76,7 @@ if __name__ == "__main__":
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4", # It is best to set it to an absolute path.
},
......@@ -84,7 +84,7 @@ if __name__ == "__main__":
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t03.mp4", # It is best to set it to an absolute path.
},
......@@ -92,7 +92,7 @@ if __name__ == "__main__":
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t04.mp4", # It is best to set it to an absolute path.
},
......@@ -100,7 +100,7 @@ if __name__ == "__main__":
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t05.mp4", # It is best to set it to an absolute path.
},
......
......@@ -27,7 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
export PYTHONPATH=/mtc/wushuo/VideoGen/diffusers:$PYTHONPATH
python -m lightx2v.infer \
......
......@@ -24,7 +24,7 @@ fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
......
......@@ -24,7 +24,7 @@ fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment