Commit 8ba6e3b4 authored by gushiqiao

Fixed the accuracy fluctuation bug

parent 793ec1db
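The fluctuation came from normalization, embedding, and residual paths silently running in BF16 regardless of the configured runtime dtype. This commit gates those sites on GET_DTYPE() (from lightx2v.utils.envs): when the runtime dtype is not BF16, the math is done in FP32 and the result is cast back to BF16 before the next weight apply. A minimal sketch of the dispatch pattern applied throughout (the helper name normed is illustrative, not part of the codebase):

import torch
from lightx2v.utils.envs import *  # provides GET_DTYPE(), as used across this diff

def normed(x, weight=None, bias=None, eps=1e-6):
    # Illustrative sketch of the pattern this commit applies at each norm site:
    # the BF16 fast path stays in BF16; otherwise normalize in FP32, cast back.
    if GET_DTYPE() == "BF16":
        return torch.nn.functional.layer_norm(x, (x.shape[-1],), weight, bias, eps)
    return torch.nn.functional.layer_norm(x.float(), (x.shape[-1],), weight, bias, eps).to(torch.bfloat16)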
@@ -4,9 +4,9 @@
     "target_height": 480,
     "target_width": 832,
     "attention_type": "flash_attn3",
-    "seed": 42,
+    "seed": 442,
     "sample_guide_scale": 5,
-    "sample_shift": 5,
+    "sample_shift": 3,
     "enable_cfg": true,
     "cpu_offload": false
 }
@@ -28,7 +28,7 @@ message = {
     "task_id": generate_task_id(),
     "task_id_must_unique": True,
     "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
-    "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
     "image_path": "",
     "save_video_path": "./output_lightx2v_wan_t2v_t02.mp4",
 }
...
@@ -28,7 +28,7 @@ message = {
     "task_id": generate_task_id(),
     "task_id_must_unique": True,
     "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
-    "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    "image_path": "",
     "save_video_path": "./output_lightx2v_wan_t2v_t02.mp4",
 }
...
@@ -657,7 +657,7 @@ curl -X 'POST' \
     "task_id_must_unique": false,
     "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline'\''s intricate details and the refreshing atmosphere of the seaside.",
     "use_prompt_enhancer": false,
-    "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
     "image_path": "/mnt/aigc/users/gaopeng1/ComfyUI-Lightx2vWrapper/lightx2v/assets/inputs/imgs/img_0.jpg",
     "num_fragments": 1,
     "save_video_path": "/mnt/aigc/users/lijiaqi2/ComfyUI/custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v/save_results/img_0.mp4"
...
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
@@ -32,10 +32,14 @@ class Conv3dWeight(Conv3dWeightTemplate):
         super().__init__(weight_name, bias_name, stride, padding, dilation, groups)

     def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.weight = weight_dict[self.weight_name]
+        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
+        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
+        self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias_name is not None else None

     def apply(self, input_tensor):
+        # if input_tensor.dtype == torch.float:
+        #     input_tensor = input_tensor.to(torch.bfloat16)
         input_tensor = torch.nn.functional.conv3d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
         return input_tensor
@@ -56,13 +60,3 @@ class Conv3dWeight(Conv3dWeightTemplate):
         if self.bias is not None:
             destination[self.bias_name] = self.bias.cpu().detach().clone()
         return destination
-
-
-@CONV3D_WEIGHT_REGISTER("Defaultt-Force-BF16")
-class Conv3dWeightForceBF16(Conv3dWeight):
-    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
-        super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
-
-    def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].to(torch.bfloat16).cuda()
-        self.bias = weight_dict[self.bias_name].to(torch.bfloat16).cuda() if self.bias_name is not None else None
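For context on the load() change above: weights now stay on the host after loading, and load() pre-allocates page-locked (pinned) staging buffers, which is what makes a later non-blocking host-to-device copy possible on an offload stream. A self-contained sketch of that copy pattern, assuming a CUDA device (tensor shape and stream handling here are illustrative, not lifted from the repo):

import torch

weight_cpu = torch.randn(64, 16, 3, 3, 3, dtype=torch.bfloat16)  # stand-in for a conv3d weight
pinned = torch.empty(weight_cpu.shape, pin_memory=True, dtype=weight_cpu.dtype)

copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    pinned.copy_(weight_cpu)                           # host -> pinned staging buffer
    weight_gpu = pinned.to("cuda", non_blocking=True)  # async H2D copy, can overlap compute
torch.cuda.current_stream().wait_stream(copy_stream)   # sync before weight_gpu is used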
@@ -79,6 +79,7 @@ class MMWeight(MMWeightTemplate):
         self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None

     def apply(self, input_tensor):
+        # if input_tensor.dtype != torch.float
         shape = (input_tensor.shape[0], self.weight.shape[1])
         dtype = input_tensor.dtype
         device = input_tensor.device
...
 import torch
 from abc import ABCMeta, abstractmethod
 from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
+from lightx2v.utils.envs import *


 class LNWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
+    def __init__(self, weight_name=None, bias_name=None, lazy_load=False, lazy_load_file=None, eps=1e-6):
         self.weight_name = weight_name
         self.bias_name = bias_name
         self.eps = eps
@@ -12,23 +13,6 @@ class LNWeightTemplate(metaclass=ABCMeta):
         self.lazy_load_file = lazy_load_file
         self.config = {}

-    def load_from_disk(self):
-        if self.weight_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
-            else:
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
-        else:
-            self.weight = None
-        if self.bias_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
-            else:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
-        else:
-            self.bias = None
-
     def load(self, weight_dict):
         if not self.lazy_load:
             if self.weight_name is not None:
@@ -89,9 +73,35 @@ class LNWeightTemplate(metaclass=ABCMeta):
 @LN_WEIGHT_REGISTER("Default")
 class LNWeight(LNWeightTemplate):
-    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
+    def __init__(self, weight_name=None, bias_name=None, lazy_load=False, lazy_load_file=None, eps=1e-6):
         super().__init__(weight_name, bias_name, lazy_load, lazy_load_file, eps)

+    def load_from_disk(self):
+        if self.weight_name is not None:
+            if not torch._dynamo.is_compiling():
+                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
+            else:
+                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
+        else:
+            self.weight = None
+        if self.bias_name is not None:
+            if not torch._dynamo.is_compiling():
+                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
+            else:
+                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
+        else:
+            self.bias = None
+
     def apply(self, input_tensor):
-        input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
+        if self.weight is not None and self.weight.dtype == torch.bfloat16:
+            input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
+        else:
+            input_tensor = torch.nn.functional.layer_norm(
+                input_tensor.float(),
+                (input_tensor.shape[-1],),
+                self.weight,
+                self.bias,
+                self.eps,
+            ).to(torch.bfloat16)
         return input_tensor
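A quick way to see why LNWeight.apply now normalizes in FP32 whenever the weight is not BF16: with eps=1e-6, BF16's 8-bit mantissa perturbs the normalized output enough to shift sampled frames between runs. A hedged repro sketch (the hidden size and the error magnitude are illustrative; both vary with seed and shape):

import torch

torch.manual_seed(0)
x = torch.randn(1, 5120)  # hypothetical hidden size

ref = torch.nn.functional.layer_norm(x, (x.shape[-1],), None, None, 1e-6)
bf16 = torch.nn.functional.layer_norm(x.bfloat16(), (x.shape[-1],), None, None, 1e-6).float()

# Max abs deviation of the all-BF16 norm from the FP32 reference;
# typically on the order of 1e-2, and it compounds across transformer blocks.
print((ref - bf16).abs().max())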
 import torch
 from abc import ABCMeta, abstractmethod
 from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
+from lightx2v.utils.envs import *

 try:
     import sgl_kernel
@@ -16,12 +17,6 @@ class RMSWeightTemplate(metaclass=ABCMeta):
         self.lazy_load_file = lazy_load_file
         self.config = {}

-    def load_from_disk(self):
-        if not torch._dynamo.is_compiling():
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
-        else:
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
-
     def load(self, weight_dict):
         if not self.lazy_load:
             self.weight = weight_dict[self.weight_name]
@@ -56,9 +51,25 @@ class RMSWeight(RMSWeightTemplate):
     def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
         super().__init__(weight_name, lazy_load, lazy_load_file, eps)

+    def load(self, weight_dict):
+        if not self.lazy_load:
+            self.weight = weight_dict[self.weight_name]
+            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
+
+    def load_from_disk(self):
+        if not torch._dynamo.is_compiling():
+            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
+        else:
+            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
+
     def apply(self, input_tensor):
-        input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
-        input_tensor = input_tensor * self.weight
+        if GET_DTYPE() == "BF16":
+            input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
+            input_tensor = input_tensor * self.weight
+        else:
+            input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps)
+            input_tensor = (input_tensor * self.weight).to(torch.bfloat16)
         return input_tensor

     def state_dict(self, destination=None):
@@ -68,32 +79,36 @@ class RMSWeight(RMSWeightTemplate):
         return destination
@RMS_WEIGHT_REGISTER("FP32")
class RMSWeightFP32(RMSWeight):
def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
super().__init__(weight_name, lazy_load, lazy_load_file, eps)
def apply(self, input_tensor):
input_tensor = input_tensor.float()
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor.to(torch.bfloat16)
input_tensor = input_tensor * self.weight
return input_tensor
@RMS_WEIGHT_REGISTER("sgl-kernel") @RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight): class RMSWeightSgl(RMSWeight):
def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6): def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
super().__init__(weight_name, lazy_load, lazy_load_file, eps) super().__init__(weight_name, lazy_load, lazy_load_file, eps)
def apply(self, input_tensor): def load(self, weight_dict):
if sgl_kernel is None: if not self.lazy_load:
# sgl_kernel is not available, fallback to default implementation self.weight = weight_dict[self.weight_name]
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps) self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
input_tensor = input_tensor * self.weight
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
else: else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
def apply(self, input_tensor):
use_bf16 = GET_DTYPE() == "BF16"
if sgl_kernel is not None and use_bf16:
input_tensor = input_tensor.contiguous() input_tensor = input_tensor.contiguous()
orig_shape = input_tensor.shape orig_shape = input_tensor.shape
input_tensor = input_tensor.view(-1, orig_shape[-1]) input_tensor = input_tensor.view(-1, orig_shape[-1])
input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape) input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
else:
# sgl_kernel is not available or dtype!=torch.bfloat16, fallback to default implementation
if use_bf16:
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor * self.weight
else:
input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).type_as(input_tensor)
input_tensor = (input_tensor * self.weight).type_as(input_tensor)
return input_tensor return input_tensor
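Condensed, the two RMSWeight.apply branches compute the same RMSNorm and differ only in accumulation dtype. A reference sketch mirroring the pair, with the fp32 flag standing in for GET_DTYPE() != "BF16":

import torch

def rms_norm_ref(x, weight, eps=1e-6, fp32=True):
    if fp32:
        # Non-BF16 branch: mean-of-squares in FP32 (x promotes on multiply), cast back at the end.
        y = x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
        return (y * weight).to(torch.bfloat16)
    # BF16 fast path: everything stays in the input dtype.
    y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return y * weight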
 import torch
 from lightx2v.utils.registry_factory import TENSOR_REGISTER
-from safetensors import safe_open
+from lightx2v.utils.envs import *


 @TENSOR_REGISTER("Default")
@@ -18,7 +18,7 @@ class DefaultTensor:
     def load(self, weight_dict):
         if not self.lazy_load:
-            self.tensor = weight_dict[self.tensor_name].to(torch.bfloat16)
+            self.tensor = weight_dict[self.tensor_name]
             self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)

     def clear(self):
...
 import os
 import torch
-import time
-import glob
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
@@ -13,10 +11,7 @@ from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
 from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
     WanTransformerInferCausVid,
 )
-from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
-from safetensors import safe_open
-import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
-import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
+from lightx2v.utils.envs import *


 class WanCausVidModel(WanModel):
@@ -33,7 +28,7 @@ class WanCausVidModel(WanModel):
         self.transformer_infer_class = WanTransformerInferCausVid

     def _load_ckpt(self):
-        use_bfloat16 = self.config.get("use_bfloat16", True)
+        use_bfloat16 = GET_DTYPE() == "BF16"
         ckpt_path = os.path.join(self.model_path, "causal_model.pt")
         if not os.path.exists(ckpt_path):
             # File does not exist; fall back to the parent class's _load_ckpt method
...
@@ -9,19 +9,7 @@ from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
-from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
-from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
-from lightx2v.models.networks.wan.infer.transformer_infer import (
-    WanTransformerInfer,
-)
-from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
-    WanTransformerInferTeaCaching,
-)
-from safetensors import safe_open
-import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
-import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
-from loguru import logger


 class WanDistillModel(WanModel):
@@ -33,7 +21,7 @@ class WanDistillModel(WanModel):
         super().__init__(model_path, config, device)

     def _load_ckpt(self):
-        use_bfloat16 = self.config.get("use_bfloat16", True)
+        use_bfloat16 = GET_DTYPE() == "BF16"
         ckpt_path = os.path.join(self.model_path, "distill_model.pt")
         if not os.path.exists(ckpt_path):
             # File does not exist; fall back to the parent class's _load_ckpt method
...
 import math
 import torch
 import torch.cuda.amp as amp
+from lightx2v.utils.envs import *


 class WanPostInfer:
@@ -20,8 +21,14 @@ class WanPostInfer:
         e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
         e = [ei.squeeze(1) for ei in e]

-        norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
+        norm_out = weights.norm.apply(x)
+        if GET_DTYPE() != "BF16":
+            norm_out = norm_out.float()
         out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
+        if GET_DTYPE() != "BF16":
+            out = out.to(torch.bfloat16)
         x = weights.head.apply(out)
         x = self.unpatchify(x, grid_sizes)
         return [u.float() for u in x]
...
 import torch
-import math
 from .utils import rope_params, sinusoidal_embedding_1d
+from lightx2v.utils.envs import *


 class WanPreInfer:
@@ -60,7 +60,10 @@ class WanPreInfer:
         x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])

         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
-        embed = weights.time_embedding_0.apply(embed)
+        if GET_DTYPE() != "BF16":
+            embed = weights.time_embedding_0.apply(embed.float())
+        else:
+            embed = weights.time_embedding_0.apply(embed)
         embed = torch.nn.functional.silu(embed)
         embed = weights.time_embedding_2.apply(embed)
         embed0 = torch.nn.functional.silu(embed)
@@ -78,7 +81,10 @@ class WanPreInfer:
         # text embeddings
         stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
-        out = weights.text_embedding_0.apply(stacked.squeeze(0))
+        if GET_DTYPE() != "BF16":
+            out = weights.text_embedding_0.apply(stacked.squeeze(0).float())
+        else:
+            out = weights.text_embedding_0.apply(stacked.squeeze(0))
         out = torch.nn.functional.gelu(out, approximate="tanh")
         context = weights.text_embedding_2.apply(out)
@@ -88,7 +94,6 @@ class WanPreInfer:
         context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
         context_clip = weights.proj_3.apply(context_clip)
         context_clip = weights.proj_4.apply(context_clip)
-
         context = torch.concat([context_clip, context], dim=0)

         return (
...
 import torch
 from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
-from lightx2v.common.offload.manager import WeightAsyncStreamManager, LazyWeightAsyncStreamManager
+from lightx2v.common.offload.manager import (
+    WeightAsyncStreamManager,
+    LazyWeightAsyncStreamManager,
+)
 from lightx2v.utils.envs import *
@@ -90,6 +93,7 @@ class WanTransformerInfer:
         if embed0.dim() == 3:
             modulation = weights.blocks[block_idx].modulation.tensor.unsqueeze(2)
             current_embed0 = (modulation + embed0).chunk(6, dim=1)
+
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in current_embed0]
         elif embed0.dim() == 2:
@@ -217,10 +221,18 @@ class WanTransformerInfer:
         norm1_weight = 1 + scale_msa
         norm1_bias = shift_msa

-        norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
+        norm1_out = weights.norm1.apply(x)
+        if GET_DTYPE() != "BF16":
+            norm1_out = norm1_out.float()
         norm1_out = (norm1_out * norm1_weight + norm1_bias).squeeze(0)
+        if GET_DTYPE() != "BF16":
+            norm1_out = norm1_out.to(torch.bfloat16)

         s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
         q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d)
         k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
         v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
@@ -257,28 +269,35 @@ class WanTransformerInfer:
         )
         y = weights.self_attn_o.apply(attn_out)
-        x.add_(y * gate_msa.squeeze(0))
+
+        if GET_DTYPE() != "BF16":
+            x = x.float() + y.float() * gate_msa.squeeze(0)
+        else:
+            x.add_(y * gate_msa.squeeze(0))
         return x
     def _infer_cross_attn(self, weights, x, context):
         norm3_out = weights.norm3.apply(x)
         if self.task == "i2v":
             context_img = context[:257]
             context = context[257:]
         else:
             context_img = None

+        if GET_DTYPE() != "BF16":
+            context = context.to(torch.bfloat16)
+            if self.task == "i2v":
+                context_img = context_img.to(torch.bfloat16)
+
         n, d = self.num_heads, self.head_dim
         q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
         k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
         v = weights.cross_attn_v.apply(context).view(-1, n, d)

         cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
             q,
             k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device),
         )

         attn_out = weights.cross_attn_1.apply(
             q=q,
             k=k,
@@ -309,7 +328,6 @@ class WanTransformerInfer:
             max_seqlen_kv=k_img.size(0),
             model_cls=self.config["model_cls"],
         )
-
         attn_out = attn_out + img_attn_out
         attn_out = weights.cross_attn_o.apply(attn_out)
@@ -324,11 +342,21 @@ class WanTransformerInfer:
         norm2_weight = 1 + c_scale_msa.squeeze(0)
         norm2_bias = c_shift_msa.squeeze(0)

-        norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
-        y = weights.ffn_0.apply(norm2_out * norm2_weight + norm2_bias)
+        norm2_out = weights.norm2.apply(x)
+        if GET_DTYPE() != "BF16":
+            norm2_out = norm2_out.float()
+        norm2_out = norm2_out * norm2_weight + norm2_bias
+        if GET_DTYPE() != "BF16":
+            norm2_out = norm2_out.to(torch.bfloat16)
+        y = weights.ffn_0.apply(norm2_out)
         y = torch.nn.functional.gelu(y, approximate="tanh")
         y = weights.ffn_2.apply(y)
-        x.add_(y * c_gate_msa.squeeze(0))
+
+        if GET_DTYPE() != "BF16":
+            x = x.float() + y.float() * c_gate_msa.squeeze(0)
+        else:
+            x.add_(y * c_gate_msa.squeeze(0))
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
...
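The residual updates above follow the same split: the BF16 path keeps the fused in-place x.add_(), while the non-BF16 path has to add out-of-place because upcasting allocates a new tensor anyway. A compact sketch of the two branches (gate stands for the squeezed gate_msa / c_gate_msa modulation):

import torch

def gated_residual(x, y, gate, bf16):
    if bf16:
        x.add_(y * gate)  # fused in-place update, no extra allocation
        return x
    # FP32 accumulation: upcast both operands and add out-of-place.
    return x.float() + y.float() * gate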
 import torch
 import torch.distributed as dist
+from lightx2v.utils.envs import *


 def compute_freqs(c, grid_sizes, freqs):
@@ -70,8 +71,12 @@ def apply_rotary_emb(x, freqs_i):
     x_i = torch.view_as_complex(x[:seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
     # Apply rotary embedding
     x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
-    x_i = torch.cat([x_i, x[seq_len:]]).to(torch.bfloat16)
-    return x_i
+    x_i = torch.cat([x_i, x[seq_len:]])
+    # if GET_DTYPE() == "BF16":
+    #     x_i = x_i.to(torch.bfloat16)
+    # else:
+    #     x_i = x_i.float()
+    return x_i.to(torch.bfloat16)


 def rope_params(max_seq_len, dim, theta=10000):
@@ -92,5 +97,7 @@ def sinusoidal_embedding_1d(dim, position):
     # calculation
     sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
-    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(torch.bfloat16)
+    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+    if GET_DTYPE() == "BF16":
+        x = x.to(torch.bfloat16)
     return x
@@ -3,6 +3,7 @@ import torch
 from safetensors import safe_open
 from loguru import logger
 import gc
+from lightx2v.utils.envs import *


 class WanLoraWrapper:
@@ -25,7 +26,7 @@ class WanLoraWrapper:
         return lora_name

     def _load_lora_file(self, file_path):
-        use_bfloat16 = True  # Default value
+        use_bfloat16 = GET_DTYPE() == "BF16"
         if self.model.config and hasattr(self.model.config, "get"):
             use_bfloat16 = self.model.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
...