Commit 3e215bad authored by gushiqiao

Support bf16/fp16 inference, and mixed-precision inference that keeps precision-sensitive layers in fp32

parent e684202c
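The changes below route every dtype decision through two helpers imported from lightx2v.utils.envs: GET_DTYPE() for the unified inference dtype and GET_SENSITIVE_DTYPE() for precision-sensitive layers. Their implementation is not part of this diff; a minimal sketch, assuming they are driven by environment variables named DTYPE and SENSITIVE_DTYPE (the variable names are assumptions), could look like:

```python
# Hypothetical sketch of the dtype helpers this commit relies on; the real
# lightx2v.utils.envs module may differ, and the env-var names are assumptions.
import os
import torch

_DTYPE_MAP = {"BF16": torch.bfloat16, "FP16": torch.float16, "FP32": torch.float32}

def GET_DTYPE() -> torch.dtype:
    # Unified dtype used for most weights and activations (bf16 by default).
    return _DTYPE_MAP[os.environ.get("DTYPE", "BF16")]

def GET_SENSITIVE_DTYPE() -> torch.dtype:
    # Dtype for precision-sensitive layers (norms, modulation, embeddings).
    # Setting it to FP32 while DTYPE stays BF16/FP16 enables the mixed-precision path.
    return _DTYPE_MAP[os.environ.get("SENSITIVE_DTYPE", os.environ.get("DTYPE", "BF16"))]
```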
......@@ -3,6 +3,7 @@ import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
......@@ -114,7 +115,7 @@ class RingAttnWeight(AttnWeightTemplate):
k = next_k
v = next_v
attn1 = out.to(torch.bfloat16).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
if txt_mask_len > 0:
attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
......@@ -131,7 +132,7 @@ class RingAttnWeight(AttnWeightTemplate):
return_softmax=False,
)
attn2 = attn2.to(torch.bfloat16).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
attn1 = torch.cat([attn1, attn2], dim=0)
return attn1
......
......@@ -52,7 +52,7 @@ class SageAttn2Weight(AttnWeightTemplate):
)
x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "wan2.1_audio"]:
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "wan2.1_audio", "wan2.2"]:
x = sageattn(
q.unsqueeze(0),
k.unsqueeze(0),
......
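The attention hunks above show the recurring pattern of this commit: kernel outputs that were hardcoded to torch.bfloat16 now follow the configured dtype, so bf16 and fp16 inference share one code path (the sageattn hunk additionally adds "wan2.2" to the supported model list). A trivial illustration of the pattern; the helper name is hypothetical and not part of the commit:

```python
import torch
from lightx2v.utils.envs import *  # provides GET_DTYPE(), as in the diff above

def to_infer_dtype(out: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: replaces scattered `.to(torch.bfloat16)` calls so the
    # attention output dtype tracks the unified inference dtype instead.
    return out.to(GET_DTYPE())
```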
......@@ -129,6 +129,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.act_quant_func = None
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.infer_dtype = GET_DTYPE()
# =========================
# weight load functions
......@@ -139,12 +140,12 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
if self.bias_name is not None:
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name)
self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
if self.bias_name is not None:
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
if self.weight_need_transpose:
self.weight = self.weight.t()
......@@ -394,7 +395,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
self.bias.float(),
input_tensor_scale,
self.weight_scale,
out_dtype=torch.bfloat16,
out_dtype=self.infer_dtype,
)
return output_tensor.squeeze(0)
......@@ -425,7 +426,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
input_tensor_scale,
self.weight_scale,
fuse_gelu=False,
out_dtype=torch.bfloat16,
out_dtype=self.infer_dtype,
)
return output_tensor.squeeze(0)
......@@ -449,7 +450,7 @@ class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTempla
Act Scale: torch.Size([1024, 16]), torch.float32
Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
Weight Scale: torch.Size([32, 16]), torch.float32
Out : torch.Size([1024, 4096]), torch.bfloat16
Out : torch.Size([1024, 4096]), self.infer_dtype
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
......@@ -568,7 +569,7 @@ class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
self.weight,
input_tensor_scale,
self.weight_scale,
torch.bfloat16,
self.infer_dtype,
bias=self.bias,
)
return output_tensor
......@@ -598,7 +599,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
self.weight,
input_tensor_scale,
self.weight_scale,
torch.bfloat16,
self.infer_dtype,
bias=self.bias,
)
return output_tensor
......@@ -633,7 +634,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
self.weight,
input_tensor_scale,
self.weight_scale,
torch.bfloat16,
self.infer_dtype,
self.bias,
)
return output_tensor
......@@ -659,7 +660,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
def apply(self, input_tensor):
input_tensor = input_tensor
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
if self.bias is not None:
output_tensor = output_tensor + self.bias
......
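The int8 path above calls quant_int8_per_token_matmul with output_dtype=self.infer_dtype. The project's actual kernel is not shown in this diff; a reference sketch of per-token / per-channel int8 dequantization with a configurable output dtype (shapes and scale layout are assumptions) might be:

```python
import torch

def quant_int8_per_token_matmul_ref(a_int8, a_scale, w_int8, w_scale, output_dtype=torch.bfloat16):
    # a_int8: [tokens, in]  with per-token activation scales a_scale: [tokens, 1]
    # w_int8: [in, out]     with per-channel weight scales     w_scale: [1, out]
    acc = a_int8.float() @ w_int8.float()           # int8 x int8 accumulation (reference, not a fused kernel)
    out = acc * a_scale.float() * w_scale.float()   # dequantize with both scale factors
    return out.to(output_dtype)                     # cast to the configured inference dtype (bf16/fp16)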
......@@ -14,6 +14,8 @@ class LNWeightTemplate(metaclass=ABCMeta):
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.config = {}
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def load(self, weight_dict):
if not self.lazy_load:
......@@ -85,29 +87,30 @@ class LNWeight(LNWeightTemplate):
def load_from_disk(self):
if self.weight_name is not None:
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
else:
self.weight = None
if self.bias_name is not None:
if not torch._dynamo.is_compiling():
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
else:
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
else:
self.bias = None
def apply(self, input_tensor):
if GET_DTYPE() != "BF16":
if self.sensitive_layer_dtype != self.infer_dtype:
input_tensor = torch.nn.functional.layer_norm(
input_tensor.float(),
(input_tensor.shape[-1],),
self.weight,
self.bias,
self.eps,
).to(torch.bfloat16)
).to(self.infer_dtype)
else:
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
return input_tensor
......@@ -17,6 +17,8 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.eps = eps
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
self.config = {}
def load(self, weight_dict):
......@@ -64,17 +66,17 @@ class RMSWeight(RMSWeightTemplate):
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
def apply(self, input_tensor):
if GET_DTYPE() == "BF16":
if GET_SENSITIVE_DTYPE() == GET_DTYPE():
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor * self.weight
else:
input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = (input_tensor * self.weight).to(torch.bfloat16)
input_tensor = (input_tensor * self.weight).to(GET_DTYPE())
return input_tensor
......@@ -97,24 +99,23 @@ class RMSWeightSgl(RMSWeight):
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
def apply(self, input_tensor):
use_bf16 = GET_DTYPE() == "BF16"
if sgl_kernel is not None and use_bf16:
if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
input_tensor = input_tensor.contiguous()
orig_shape = input_tensor.shape
input_tensor = input_tensor.view(-1, orig_shape[-1])
input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
else:
# sgl_kernel is not available or dtype!=torch.bfloat16, fallback to default implementation
if use_bf16:
# sgl_kernel is not available or dtype!=torch.bfloat16/float16, fallback to default implementation
if self.sensitive_layer_dtype != self.infer_dtype:
input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).to(self.infer_dtype)
input_tensor = (input_tensor * self.weight).to(self.infer_dtype)
else:
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor * self.weight
else:
input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).type_as(input_tensor)
input_tensor = (input_tensor * self.weight).type_as(input_tensor)
return input_tensor
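The reason norm layers go through the sensitive dtype is that the mean/rsqrt reduction loses precision in bf16/fp16, which is why the fallback above upcasts to float when the two dtypes differ. A small self-contained check of the RMS-norm case (illustrative only; inputs are random, so printed magnitudes will vary):

```python
import torch

x_bf = (torch.randn(8, 4096) * 50.0).bfloat16()
eps = 1e-6
ref = x_bf.float() * torch.rsqrt(x_bf.float().pow(2).mean(-1, keepdim=True) + eps)  # fp32 reference
low = (x_bf * torch.rsqrt(x_bf.pow(2).mean(-1, keepdim=True) + eps)).float()         # all-bf16 branch
mix = ref.bfloat16().float()                                                          # fp32 compute, single downcast
# `mix` only pays one rounding step, while `low` compounds rounding through the
# reduction; that gap is what the sensitive-dtype branch is meant to avoid.
print((low - ref).abs().max().item(), (mix - ref).abs().max().item())
```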
......@@ -10,12 +10,14 @@ class DefaultTensor:
self.tensor_name = tensor_name
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(torch.bfloat16).pin_memory()
self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype).pin_memory()
else:
self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(torch.bfloat16)
self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
def load(self, weight_dict):
if not self.lazy_load:
......
......@@ -9,6 +9,7 @@ import torch.nn.functional as F
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
from lightx2v.utils.envs import *
from .tokenizer import HuggingfaceTokenizer
......@@ -131,7 +132,7 @@ class T5Attention(nn.Module):
if hasattr(self, "cpu_offload") and self.cpu_offload:
del attn_bias
attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum("bnij,bjnc->binc", attn, v)
if hasattr(self, "cpu_offload") and self.cpu_offload:
......@@ -356,7 +357,7 @@ class T5Encoder(nn.Module):
optimize_memory_usage()
x = self.dropout(x)
return x.to(torch.bfloat16)
return x.to(GET_DTYPE())
class T5Decoder(nn.Module):
......
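In the T5 attention hunk above, the softmax is still computed in fp32, but the result is now cast back with .type_as(attn) instead of a hardcoded bfloat16, so it follows whatever dtype the activations carry. A minimal illustration of that idiom:

```python
import torch

scores = torch.randn(1, 12, 256, 256, dtype=torch.float16)
# Upcast for the numerically sensitive reduction, then return to the activation dtype.
probs = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
assert probs.dtype == scores.dtype
```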
......@@ -12,6 +12,7 @@ from lightx2v.models.networks.cogvideox.infer.transformer_infer import Cogvideox
from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
from lightx2v.utils.envs import *
class CogvideoxModel:
......@@ -33,7 +34,7 @@ class CogvideoxModel:
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self):
......
......@@ -3,6 +3,8 @@ import math
import torch
from einops import rearrange
from lightx2v.utils.envs import *
class HunyuanPreInfer:
def __init__(self, config):
......@@ -64,7 +66,7 @@ class HunyuanPreInfer:
def infer_time_in(self, weights, t):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
out = weights.time_in_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
out = weights.time_in_mlp_2.apply(out)
......@@ -78,12 +80,12 @@ class HunyuanPreInfer:
def infer_text_in(self, weights, text_states, text_mask, t):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)
mask_float = text_mask.float().unsqueeze(-1).to(torch.bfloat16) # [b, s1, 1]
mask_float = text_mask.float().unsqueeze(-1).to(GET_DTYPE()) # [b, s1, 1]
context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = context_aware_representations
......@@ -148,7 +150,7 @@ class HunyuanPreInfer:
def infer_guidance_in(self, weights, guidance):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
args = guidance.float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
out = weights.guidance_in_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
out = weights.guidance_in_mlp_2.apply(out)
......
......@@ -2,11 +2,13 @@ from typing import Tuple, Union
import torch
from lightx2v.utils.envs import *
def rms_norm(x, weight, eps):
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.to(torch.bfloat16)
x = x.to(GET_DTYPE())
x = x * weight
return x
......@@ -18,7 +20,7 @@ def rotate_half(x, shape_0, shape_1):
def rotary_emb(x, shape_0, shape_1, cos, sin):
x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
return x_out.to(torch.bfloat16)
return x_out.to(GET_DTYPE())
def apply_rotary_emb(
......
......@@ -78,7 +78,7 @@ class HunyuanModel:
for k in f.keys():
weight_dict[k] = f.get_tensor(k)
if weight_dict[k].dtype == torch.float:
weight_dict[k] = weight_dict[k].to(torch.bfloat16)
weight_dict[k] = weight_dict[k].to(GET_DTYPE())
return weight_dict
......
......@@ -13,6 +13,8 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
from lightx2v.utils.envs import *
def load_safetensors(in_path: str):
if os.path.isdir(in_path):
......@@ -57,7 +59,7 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
model.load_state_dict(state_dict, strict=strict)
if dist.is_initialized():
dist.barrier()
return model.to(dtype=torch.bfloat16, device="cuda")
return model.to(dtype=GET_DTYPE(), device="cuda")
def linear_interpolation(features, output_len: int):
......
......@@ -67,10 +67,10 @@ class WanAudioModel(WanModel):
class Wan22MoeAudioModel(WanAudioModel):
def _load_ckpt(self, use_bf16, skip_bf16):
def _load_ckpt(self, unified_dtype, sensitive_layer):
safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
return weight_dict
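Wan22MoeAudioModel._load_ckpt now forwards unified_dtype and sensitive_layer to _load_safetensor_to_dict; that helper itself is outside this hunk. A sketch consistent with the dict comprehension used by WanCausVidModel below (the signature is inferred from the call sites, the body and the "norm" pattern are assumptions):

```python
from safetensors import safe_open
from lightx2v.utils.envs import *  # provides GET_DTYPE(), as in the diff above

def _load_safetensor_to_dict_sketch(file_path, unified_dtype, sensitive_layer, device="cuda"):
    # unified_dtype=True  -> cast every tensor to GET_DTYPE() (pure bf16/fp16 load);
    # unified_dtype=False -> keep tensors whose name matches a sensitive-layer
    #                        pattern (e.g. "norm") in their original precision.
    weight_dict = {}
    with safe_open(file_path, framework="pt") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            if unified_dtype or all(s not in key for s in sensitive_layer):
                tensor = tensor.to(GET_DTYPE())
            weight_dict[key] = tensor.to(device)
    return weight_dict
```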
......@@ -31,23 +31,25 @@ class WanCausVidModel(WanModel):
self.post_infer_class = WanPostInfer
self.transformer_infer_class = WanTransformerInferCausVid
def _load_ckpt(self, use_bf16, skip_bf16):
def _load_ckpt(self, unified_dtype, sensitive_layer):
ckpt_folder = "causvid_models"
safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
if os.path.exists(safetensors_path):
with safe_open(safetensors_path, framework="pt") as f:
weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
weight_dict = {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()
}
return weight_dict
ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
if os.path.exists(ckpt_path):
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weight_dict = {
key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
}
return weight_dict
return super()._load_ckpt(use_bf16, skip_bf16)
return super()._load_ckpt(unified_dtype, sensitive_layer)
@torch.no_grad()
def infer(self, inputs, kv_start, kv_end):
......
......@@ -20,27 +20,27 @@ class WanDistillModel(WanModel):
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _load_ckpt(self, use_bf16, skip_bf16):
def _load_ckpt(self, unified_dtype, sensitive_layer):
if self.config.get("enable_dynamic_cfg", False):
ckpt_path = os.path.join(self.model_path, "distill_cfg_models", "distill_model.safetensors")
else:
ckpt_path = os.path.join(self.model_path, "distill_models", "distill_model.safetensors")
if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}")
return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)
return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
return super()._load_ckpt(use_bf16, skip_bf16)
return super()._load_ckpt(unified_dtype, sensitive_layer)
class Wan22MoeDistillModel(WanDistillModel, Wan22MoeModel):
def __init__(self, model_path, config, device):
WanDistillModel.__init__(self, model_path, config, device)
def _load_ckpt(self, use_bf16, skip_bf16):
def _load_ckpt(self, unified_dtype, sensitive_layer):
ckpt_path = os.path.join(self.model_path, "distill_model.safetensors")
if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}")
return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)
return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
@torch.no_grad()
def infer(self, inputs):
......
......@@ -54,7 +54,7 @@ class WanAudioPreInfer(WanPreInfer):
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"]
batch_size = len(x)
num_channels, num_frames, height, width = x[0].shape
num_channels, _, height, width = x[0].shape
_, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
if ref_num_channels != num_channels:
......
......@@ -10,6 +10,8 @@ class WanPostInfer:
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......@@ -26,11 +28,11 @@ class WanPostInfer:
x = weights.norm.apply(x)
if GET_DTYPE() != "BF16":
x = x.float()
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype)
x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
if GET_DTYPE() != "BF16":
x = x.to(torch.bfloat16)
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.infer_dtype)
x = weights.head.apply(x)
x = self.unpatchify(x, grid_sizes)
......
......@@ -25,6 +25,8 @@ class WanPreInfer:
self.text_len = config["text_len"]
self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
self.cfg_scale = config.get("cfg_scale", 4.0)
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......@@ -79,8 +81,8 @@ class WanPreInfer:
cfg_embed = torch.nn.functional.silu(cfg_embed)
cfg_embed = weights.cfg_cond_proj_2.apply(cfg_embed)
embed = embed + cfg_embed
if GET_DTYPE() != "BF16":
embed = weights.time_embedding_0.apply(embed.float())
if self.sensitive_layer_dtype != self.infer_dtype:
embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
else:
embed = weights.time_embedding_0.apply(embed)
embed = torch.nn.functional.silu(embed)
......@@ -100,8 +102,8 @@ class WanPreInfer:
# text embeddings
stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
if GET_DTYPE() != "BF16":
out = weights.text_embedding_0.apply(stacked.squeeze(0).float())
if self.sensitive_layer_dtype != self.infer_dtype:
out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
else:
out = weights.text_embedding_0.apply(stacked.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
......
......@@ -30,6 +30,8 @@ class WanTransformerInfer(BaseTransformerInfer):
self.apply_rotary_emb_func = apply_rotary_emb
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.mask_map = None
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
if self.config.get("cpu_offload", False):
if torch.cuda.get_device_capability(0) == (9, 0):
......@@ -342,13 +344,13 @@ class WanTransformerInfer(BaseTransformerInfer):
norm1_out = weights.norm1.apply(x)
if GET_DTYPE() != "BF16":
norm1_out = norm1_out.float()
if self.sensitive_layer_dtype != self.infer_dtype:
norm1_out = norm1_out.to(self.sensitive_layer_dtype)
norm1_out.mul_(norm1_weight).add_(norm1_bias)
if GET_DTYPE() != "BF16":
norm1_out = norm1_out.to(torch.bfloat16)
if self.sensitive_layer_dtype != self.infer_dtype:
norm1_out = norm1_out.to(self.infer_dtype)
s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
......@@ -402,8 +404,8 @@ class WanTransformerInfer(BaseTransformerInfer):
return y
def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
if GET_DTYPE() != "BF16":
x = x.float() + y_out.float() * gate_msa.squeeze()
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze()
else:
x.add_(y_out * gate_msa.squeeze())
......@@ -414,10 +416,10 @@ class WanTransformerInfer(BaseTransformerInfer):
else:
context_img = None
if GET_DTYPE() != "BF16":
context = context.to(torch.bfloat16)
if self.sensitive_layer_dtype != self.infer_dtype:
context = context.to(self.infer_dtype)
if self.task == "i2v" and self.config.get("use_image_encoder", True):
context_img = context_img.to(torch.bfloat16)
context_img = context_img.to(self.infer_dtype)
n, d = self.num_heads, self.head_dim
......@@ -485,11 +487,11 @@ class WanTransformerInfer(BaseTransformerInfer):
norm2_bias = c_shift_msa.squeeze()
norm2_out = weights.norm2.apply(x)
if GET_DTYPE() != "BF16":
norm2_out = norm2_out.float()
if self.sensitive_layer_dtype != self.infer_dtype:
norm2_out = norm2_out.to(self.sensitive_layer_dtype)
norm2_out.mul_(norm2_weight).add_(norm2_bias)
if GET_DTYPE() != "BF16":
norm2_out = norm2_out.to(torch.bfloat16)
if self.sensitive_layer_dtype != self.infer_dtype:
norm2_out = norm2_out.to(self.infer_dtype)
y = weights.ffn_0.apply(norm2_out)
if self.clean_cuda_cache:
......@@ -503,8 +505,8 @@ class WanTransformerInfer(BaseTransformerInfer):
return y
def post_process(self, x, y, c_gate_msa):
if GET_DTYPE() != "BF16":
x = x.float() + y.float() * c_gate_msa.squeeze()
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
else:
x.add_(y * c_gate_msa.squeeze())
......
......@@ -68,7 +68,7 @@ def apply_rotary_emb(x, freqs_i):
# Apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[seq_len:]])
return x_i.to(torch.bfloat16)
return x_i.to(GET_DTYPE())
def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
......@@ -82,7 +82,7 @@ def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
freqs_chunk = freqs_i[start:end]
x_chunk_complex = torch.view_as_complex(x_chunk.to(torch.float32).reshape(end - start, n, -1, 2))
x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(torch.bfloat16)
x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(GET_DTYPE())
output_chunks.append(x_chunk_embedded)
del x_chunk_complex, x_chunk_embedded
torch.cuda.empty_cache()
......@@ -101,7 +101,7 @@ def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
del result
torch.cuda.empty_cache()
return x_i.to(torch.bfloat16)
return x_i.to(GET_DTYPE())
def rope_params(max_seq_len, dim, theta=10000):
......@@ -123,8 +123,7 @@ def sinusoidal_embedding_1d(dim, position):
# calculation
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if GET_DTYPE() == "BF16":
x = x.to(torch.bfloat16)
x = x.to(GET_SENSITIVE_DTYPE())
return x
......@@ -140,15 +139,15 @@ def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_
"""
assert len(w.shape) == 1
cfg_min, cfg_max = cfg_range
# w = torch.round(w)
# w = torch.clamp(w, min=cfg_min, max=cfg_max)
w = torch.round(w)
w = torch.clamp(w, min=cfg_min, max=cfg_max)
w = (w - cfg_min) / (cfg_max - cfg_min) # [0, 1]
w = w * target_range
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype).to(w.device) * -emb).to(w.device)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
......
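guidance_scale_embedding above turns a per-sample CFG scale into a sinusoidal feature vector for the distilled model; this hunk re-enables the round/clamp preprocessing and swaps the concatenation order to [sin, cos]. A quick usage sketch, assuming the function shown above is in scope and using only the parameters visible in its signature:

```python
import torch

w = torch.tensor([3.5, 5.0])  # per-sample guidance scales
emb = guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0))
print(emb.shape)              # expected: torch.Size([2, 256]), per the assert above
```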