Commit 3e215bad authored by gushiqiao

Support bf16/fp16 inference, and mixed-precision inference that keeps fp32 for sensitive layers

parent e684202c
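Note: the changes below replace hard-coded torch.bfloat16 with two getters imported from lightx2v.utils.envs, GET_DTYPE() and GET_SENSITIVE_DTYPE(). That module is not part of this diff; the snippet below is only a minimal sketch, assuming environment-variable-driven getters that return torch dtypes (the variable names, defaults, and mapping are illustrative assumptions, not the project's actual API).

```python
# Hypothetical sketch of lightx2v.utils.envs (not included in this commit).
import os

import torch

_DTYPE_MAP = {"BF16": torch.bfloat16, "FP16": torch.float16, "FP32": torch.float32}


def GET_DTYPE():
    # Unified inference dtype used for most weights and activations.
    return _DTYPE_MAP[os.environ.get("DTYPE", "BF16")]


def GET_SENSITIVE_DTYPE():
    # Dtype for numerically sensitive layers; e.g. set SENSITIVE_DTYPE=FP32 to keep
    # norms/embeddings in full precision while the rest runs in bf16/fp16.
    return _DTYPE_MAP[os.environ.get("SENSITIVE_DTYPE", os.environ.get("DTYPE", "BF16"))]
```

Under that convention, `sensitive_layer_dtype != infer_dtype` is exactly the mixed-precision case the new branches check for.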
@@ -3,6 +3,7 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from loguru import logger
+from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 from .template import AttnWeightTemplate
@@ -114,7 +115,7 @@ class RingAttnWeight(AttnWeightTemplate):
 k = next_k
 v = next_v
-attn1 = out.to(torch.bfloat16).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
+attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
 if txt_mask_len > 0:
 attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
@@ -131,7 +132,7 @@ class RingAttnWeight(AttnWeightTemplate):
 return_softmax=False,
 )
-attn2 = attn2.to(torch.bfloat16).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
+attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
 attn1 = torch.cat([attn1, attn2], dim=0)
 return attn1
...
@@ -52,7 +52,7 @@ class SageAttn2Weight(AttnWeightTemplate):
 )
 x = torch.cat((x1, x2), dim=1)
 x = x.view(max_seqlen_q, -1)
-elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "wan2.1_audio"]:
+elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "wan2.1_audio", "wan2.2"]:
 x = sageattn(
 q.unsqueeze(0),
 k.unsqueeze(0),
...
@@ -129,6 +129,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
 self.act_quant_func = None
 self.lazy_load = lazy_load
 self.lazy_load_file = lazy_load_file
+self.infer_dtype = GET_DTYPE()
 # =========================
 # weight load functions
@@ -139,12 +140,12 @@ class MMWeightQuantTemplate(MMWeightTemplate):
 self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
 self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
 if self.bias_name is not None:
-self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
+self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
 else:
 self.weight = self.lazy_load_file.get_tensor(self.weight_name)
 self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
 if self.bias_name is not None:
-self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
+self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
 if self.weight_need_transpose:
 self.weight = self.weight.t()
@@ -394,7 +395,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
 self.bias.float(),
 input_tensor_scale,
 self.weight_scale,
-out_dtype=torch.bfloat16,
+out_dtype=self.infer_dtype,
 )
 return output_tensor.squeeze(0)
@@ -425,7 +426,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
 input_tensor_scale,
 self.weight_scale,
 fuse_gelu=False,
-out_dtype=torch.bfloat16,
+out_dtype=self.infer_dtype,
 )
 return output_tensor.squeeze(0)
@@ -449,7 +450,7 @@ class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTempla
 Act Scale: torch.Size([1024, 16]), torch.float32
 Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
 Weight Scale: torch.Size([32, 16]), torch.float32
-Out : torch.Size([1024, 4096]), torch.bfloat16
+Out : torch.Size([1024, 4096]), self.infer_dtype
 """
 def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
@@ -568,7 +569,7 @@ class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
 self.weight,
 input_tensor_scale,
 self.weight_scale,
-torch.bfloat16,
+self.infer_dtype,
 bias=self.bias,
 )
 return output_tensor
@@ -598,7 +599,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
 self.weight,
 input_tensor_scale,
 self.weight_scale,
-torch.bfloat16,
+self.infer_dtype,
 bias=self.bias,
 )
 return output_tensor
@@ -633,7 +634,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
 self.weight,
 input_tensor_scale,
 self.weight_scale,
-torch.bfloat16,
+self.infer_dtype,
 self.bias,
 )
 return output_tensor
@@ -659,7 +660,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
 def apply(self, input_tensor):
 input_tensor = input_tensor
 input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
+output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
 if self.bias is not None:
 output_tensor = output_tensor + self.bias
...
@@ -14,6 +14,8 @@ class LNWeightTemplate(metaclass=ABCMeta):
 self.lazy_load = lazy_load
 self.lazy_load_file = lazy_load_file
 self.config = {}
+self.infer_dtype = GET_DTYPE()
+self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 def load(self, weight_dict):
 if not self.lazy_load:
@@ -85,29 +87,30 @@ class LNWeight(LNWeightTemplate):
 def load_from_disk(self):
 if self.weight_name is not None:
 if not torch._dynamo.is_compiling():
-self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
+self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
 else:
-self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
+self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
 else:
 self.weight = None
 if self.bias_name is not None:
 if not torch._dynamo.is_compiling():
-self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
+self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
 else:
-self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
+self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
 else:
 self.bias = None
 def apply(self, input_tensor):
-if GET_DTYPE() != "BF16":
+if self.sensitive_layer_dtype != self.infer_dtype:
 input_tensor = torch.nn.functional.layer_norm(
 input_tensor.float(),
 (input_tensor.shape[-1],),
 self.weight,
 self.bias,
 self.eps,
-).to(torch.bfloat16)
+).to(self.infer_dtype)
 else:
 input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
 return input_tensor
@@ -17,6 +17,8 @@ class RMSWeightTemplate(metaclass=ABCMeta):
 self.eps = eps
 self.lazy_load = lazy_load
 self.lazy_load_file = lazy_load_file
+self.infer_dtype = GET_DTYPE()
+self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 self.config = {}
 def load(self, weight_dict):
@@ -64,17 +66,17 @@ class RMSWeight(RMSWeightTemplate):
 def load_from_disk(self):
 if not torch._dynamo.is_compiling():
-self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
+self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
 else:
-self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
+self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
 def apply(self, input_tensor):
-if GET_DTYPE() == "BF16":
+if GET_SENSITIVE_DTYPE() != GET_DTYPE():
 input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
 input_tensor = input_tensor * self.weight
 else:
 input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps)
-input_tensor = (input_tensor * self.weight).to(torch.bfloat16)
+input_tensor = (input_tensor * self.weight).to(GET_DTYPE())
 return input_tensor
@@ -97,24 +99,23 @@ class RMSWeightSgl(RMSWeight):
 def load_from_disk(self):
 if not torch._dynamo.is_compiling():
-self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
+self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
 else:
-self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
+self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
 def apply(self, input_tensor):
-use_bf16 = GET_DTYPE() == "BF16"
-if sgl_kernel is not None and use_bf16:
+if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
 input_tensor = input_tensor.contiguous()
 orig_shape = input_tensor.shape
 input_tensor = input_tensor.view(-1, orig_shape[-1])
 input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
 else:
-# sgl_kernel is not available or dtype!=torch.bfloat16, fallback to default implementation
-if use_bf16:
+# sgl_kernel is not available or dtype!=torch.bfloat16/float16, fallback to default implementation
+if self.sensitive_layer_dtype != self.infer_dtype:
+input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).to(self.infer_dtype)
+input_tensor = (input_tensor * self.weight).to(self.infer_dtype)
+else:
 input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
 input_tensor = input_tensor * self.weight
-else:
-input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).type_as(input_tensor)
-input_tensor = (input_tensor * self.weight).type_as(input_tensor)
 return input_tensor
@@ -10,12 +10,14 @@ class DefaultTensor:
 self.tensor_name = tensor_name
 self.lazy_load = lazy_load
 self.lazy_load_file = lazy_load_file
+self.infer_dtype = GET_DTYPE()
+self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 def load_from_disk(self):
 if not torch._dynamo.is_compiling():
-self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(torch.bfloat16).pin_memory()
+self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype).pin_memory()
 else:
-self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(torch.bfloat16)
+self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
 def load(self, weight_dict):
 if not self.lazy_load:
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from lightx2v.utils.envs import *
 from .tokenizer import HuggingfaceTokenizer
@@ -131,7 +132,7 @@ class T5Attention(nn.Module):
 if hasattr(self, "cpu_offload") and self.cpu_offload:
 del attn_bias
-attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
+attn = F.softmax(attn.float(), dim=-1).type_as(attn)
 x = torch.einsum("bnij,bjnc->binc", attn, v)
 if hasattr(self, "cpu_offload") and self.cpu_offload:
@@ -356,7 +357,7 @@ class T5Encoder(nn.Module):
 optimize_memory_usage()
 x = self.dropout(x)
-return x.to(torch.bfloat16)
+return x.to(GET_DTYPE())
 class T5Decoder(nn.Module):
...
@@ -12,6 +12,7 @@ from lightx2v.models.networks.cogvideox.infer.transformer_infer import Cogvideox
 from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
 from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
 from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
+from lightx2v.utils.envs import *
 class CogvideoxModel:
@@ -33,7 +34,7 @@ class CogvideoxModel:
 def _load_safetensor_to_dict(self, file_path):
 with safe_open(file_path, framework="pt") as f:
-tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
+tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
 return tensor_dict
 def _load_ckpt(self):
...
@@ -3,6 +3,8 @@ import math
 import torch
 from einops import rearrange
+from lightx2v.utils.envs import *
 class HunyuanPreInfer:
 def __init__(self, config):
@@ -64,7 +66,7 @@ class HunyuanPreInfer:
 def infer_time_in(self, weights, t):
 freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
 args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
-embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
+embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
 out = weights.time_in_mlp_0.apply(embedding)
 out = torch.nn.functional.silu(out)
 out = weights.time_in_mlp_2.apply(out)
@@ -78,12 +80,12 @@ class HunyuanPreInfer:
 def infer_text_in(self, weights, text_states, text_mask, t):
 freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
 args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
-embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
+embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
 out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
 out = torch.nn.functional.silu(out)
 timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)
-mask_float = text_mask.float().unsqueeze(-1).to(torch.bfloat16) # [b, s1, 1]
+mask_float = text_mask.float().unsqueeze(-1).to(GET_DTYPE()) # [b, s1, 1]
 context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
 context_aware_representations = context_aware_representations
@@ -148,7 +150,7 @@ class HunyuanPreInfer:
 def infer_guidance_in(self, weights, guidance):
 freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
 args = guidance.float() * freqs[None]
-embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=torch.bfloat16)
+embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
 out = weights.guidance_in_mlp_0.apply(embedding)
 out = torch.nn.functional.silu(out)
 out = weights.guidance_in_mlp_2.apply(out)
...
@@ -2,11 +2,13 @@ from typing import Tuple, Union
 import torch
+from lightx2v.utils.envs import *
 def rms_norm(x, weight, eps):
 x = x.float()
 x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
-x = x.to(torch.bfloat16)
+x = x.to(GET_DTYPE())
 x = x * weight
 return x
@@ -18,7 +20,7 @@ def rotate_half(x, shape_0, shape_1):
 def rotary_emb(x, shape_0, shape_1, cos, sin):
 x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
-return x_out.to(torch.bfloat16)
+return x_out.to(GET_DTYPE())
 def apply_rotary_emb(
...
@@ -78,7 +78,7 @@ class HunyuanModel:
 for k in f.keys():
 weight_dict[k] = f.get_tensor(k)
 if weight_dict[k].dtype == torch.float:
-weight_dict[k] = weight_dict[k].to(torch.bfloat16)
+weight_dict[k] = weight_dict[k].to(GET_DTYPE())
 return weight_dict
...
@@ -13,6 +13,8 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
+from lightx2v.utils.envs import *
 def load_safetensors(in_path: str):
 if os.path.isdir(in_path):
@@ -57,7 +59,7 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
 model.load_state_dict(state_dict, strict=strict)
 if dist.is_initialized():
 dist.barrier()
-return model.to(dtype=torch.bfloat16, device="cuda")
+return model.to(dtype=GET_DTYPE(), device="cuda")
 def linear_interpolation(features, output_len: int):
...
@@ -67,10 +67,10 @@ class WanAudioModel(WanModel):
 class Wan22MoeAudioModel(WanAudioModel):
-def _load_ckpt(self, use_bf16, skip_bf16):
+def _load_ckpt(self, unified_dtype, sensitive_layer):
 safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
 weight_dict = {}
 for file_path in safetensors_files:
-file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
 weight_dict.update(file_weights)
 return weight_dict
@@ -31,23 +31,25 @@ class WanCausVidModel(WanModel):
 self.post_infer_class = WanPostInfer
 self.transformer_infer_class = WanTransformerInferCausVid
-def _load_ckpt(self, use_bf16, skip_bf16):
+def _load_ckpt(self, unified_dtype, sensitive_layer):
 ckpt_folder = "causvid_models"
 safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
 if os.path.exists(safetensors_path):
 with safe_open(safetensors_path, framework="pt") as f:
-weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+weight_dict = {
+key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()
+}
 return weight_dict
 ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
 if os.path.exists(ckpt_path):
 weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
 weight_dict = {
-key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
+key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
 }
 return weight_dict
-return super()._load_ckpt(use_bf16, skip_bf16)
+return super()._load_ckpt(unified_dtype, sensitive_layer)
 @torch.no_grad()
 def infer(self, inputs, kv_start, kv_end):
...
@@ -20,27 +20,27 @@ class WanDistillModel(WanModel):
 def __init__(self, model_path, config, device):
 super().__init__(model_path, config, device)
-def _load_ckpt(self, use_bf16, skip_bf16):
+def _load_ckpt(self, unified_dtype, sensitive_layer):
 if self.config.get("enable_dynamic_cfg", False):
 ckpt_path = os.path.join(self.model_path, "distill_cfg_models", "distill_model.safetensors")
 else:
 ckpt_path = os.path.join(self.model_path, "distill_models", "distill_model.safetensors")
 if os.path.exists(ckpt_path):
 logger.info(f"Loading weights from {ckpt_path}")
-return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)
+return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
-return super()._load_ckpt(use_bf16, skip_bf16)
+return super()._load_ckpt(unified_dtype, sensitive_layer)
 class Wan22MoeDistillModel(WanDistillModel, Wan22MoeModel):
 def __init__(self, model_path, config, device):
 WanDistillModel.__init__(self, model_path, config, device)
-def _load_ckpt(self, use_bf16, skip_bf16):
+def _load_ckpt(self, unified_dtype, sensitive_layer):
 ckpt_path = os.path.join(self.model_path, "distill_model.safetensors")
 if os.path.exists(ckpt_path):
 logger.info(f"Loading weights from {ckpt_path}")
-return self._load_safetensor_to_dict(ckpt_path, use_bf16, skip_bf16)
+return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
 @torch.no_grad()
 def infer(self, inputs):
...
@@ -54,7 +54,7 @@ class WanAudioPreInfer(WanPreInfer):
 clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
 ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"]
 batch_size = len(x)
-num_channels, num_frames, height, width = x[0].shape
+num_channels, _, height, width = x[0].shape
 _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
 if ref_num_channels != num_channels:
...
@@ -10,6 +10,8 @@ class WanPostInfer:
 self.out_dim = config["out_dim"]
 self.patch_size = (1, 2, 2)
 self.clean_cuda_cache = config.get("clean_cuda_cache", False)
+self.infer_dtype = GET_DTYPE()
+self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 def set_scheduler(self, scheduler):
 self.scheduler = scheduler
@@ -26,11 +28,11 @@ class WanPostInfer:
 x = weights.norm.apply(x)
-if GET_DTYPE() != "BF16":
-x = x.float()
+if self.sensitive_layer_dtype != self.infer_dtype:
+x = x.to(self.sensitive_layer_dtype)
 x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
-if GET_DTYPE() != "BF16":
-x = x.to(torch.bfloat16)
+if self.sensitive_layer_dtype != self.infer_dtype:
+x = x.to(self.infer_dtype)
 x = weights.head.apply(x)
 x = self.unpatchify(x, grid_sizes)
...
@@ -25,6 +25,8 @@ class WanPreInfer:
 self.text_len = config["text_len"]
 self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
 self.cfg_scale = config.get("cfg_scale", 4.0)
+self.infer_dtype = GET_DTYPE()
+self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 def set_scheduler(self, scheduler):
 self.scheduler = scheduler
@@ -79,8 +81,8 @@ class WanPreInfer:
 cfg_embed = torch.nn.functional.silu(cfg_embed)
 cfg_embed = weights.cfg_cond_proj_2.apply(cfg_embed)
 embed = embed + cfg_embed
-if GET_DTYPE() != "BF16":
-embed = weights.time_embedding_0.apply(embed.float())
+if self.sensitive_layer_dtype != self.infer_dtype:
+embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
 else:
 embed = weights.time_embedding_0.apply(embed)
 embed = torch.nn.functional.silu(embed)
@@ -100,8 +102,8 @@ class WanPreInfer:
 # text embeddings
 stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
-if GET_DTYPE() != "BF16":
-out = weights.text_embedding_0.apply(stacked.squeeze(0).float())
+if self.sensitive_layer_dtype != self.infer_dtype:
+out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
 else:
 out = weights.text_embedding_0.apply(stacked.squeeze(0))
 out = torch.nn.functional.gelu(out, approximate="tanh")
...
@@ -30,6 +30,8 @@ class WanTransformerInfer(BaseTransformerInfer):
 self.apply_rotary_emb_func = apply_rotary_emb
 self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
 self.mask_map = None
+self.infer_dtype = GET_DTYPE()
+self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 if self.config.get("cpu_offload", False):
 if torch.cuda.get_device_capability(0) == (9, 0):
@@ -342,13 +344,13 @@ class WanTransformerInfer(BaseTransformerInfer):
 norm1_out = weights.norm1.apply(x)
-if GET_DTYPE() != "BF16":
-norm1_out = norm1_out.float()
+if self.sensitive_layer_dtype != self.infer_dtype:
+norm1_out = norm1_out.to(self.sensitive_layer_dtype)
 norm1_out.mul_(norm1_weight).add_(norm1_bias)
-if GET_DTYPE() != "BF16":
-norm1_out = norm1_out.to(torch.bfloat16)
+if self.sensitive_layer_dtype != self.infer_dtype:
+norm1_out = norm1_out.to(self.infer_dtype)
 s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
@@ -402,8 +404,8 @@ class WanTransformerInfer(BaseTransformerInfer):
 return y
 def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
-if GET_DTYPE() != "BF16":
-x = x.float() + y_out.float() * gate_msa.squeeze()
+if self.sensitive_layer_dtype != self.infer_dtype:
+x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze()
 else:
 x.add_(y_out * gate_msa.squeeze())
@@ -414,10 +416,10 @@ class WanTransformerInfer(BaseTransformerInfer):
 else:
 context_img = None
-if GET_DTYPE() != "BF16":
-context = context.to(torch.bfloat16)
+if self.sensitive_layer_dtype != self.infer_dtype:
+context = context.to(self.infer_dtype)
 if self.task == "i2v" and self.config.get("use_image_encoder", True):
-context_img = context_img.to(torch.bfloat16)
+context_img = context_img.to(self.infer_dtype)
 n, d = self.num_heads, self.head_dim
@@ -485,11 +487,11 @@ class WanTransformerInfer(BaseTransformerInfer):
 norm2_bias = c_shift_msa.squeeze()
 norm2_out = weights.norm2.apply(x)
-if GET_DTYPE() != "BF16":
-norm2_out = norm2_out.float()
+if self.sensitive_layer_dtype != self.infer_dtype:
+norm2_out = norm2_out.to(self.sensitive_layer_dtype)
 norm2_out.mul_(norm2_weight).add_(norm2_bias)
-if GET_DTYPE() != "BF16":
-norm2_out = norm2_out.to(torch.bfloat16)
+if self.sensitive_layer_dtype != self.infer_dtype:
+norm2_out = norm2_out.to(self.infer_dtype)
 y = weights.ffn_0.apply(norm2_out)
 if self.clean_cuda_cache:
@@ -503,8 +505,8 @@ class WanTransformerInfer(BaseTransformerInfer):
 return y
 def post_process(self, x, y, c_gate_msa):
-if GET_DTYPE() != "BF16":
-x = x.float() + y.float() * c_gate_msa.squeeze()
+if self.sensitive_layer_dtype != self.infer_dtype:
+x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
 else:
 x.add_(y * c_gate_msa.squeeze())
...
@@ -68,7 +68,7 @@ def apply_rotary_emb(x, freqs_i):
 # Apply rotary embedding
 x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
 x_i = torch.cat([x_i, x[seq_len:]])
-return x_i.to(torch.bfloat16)
+return x_i.to(GET_DTYPE())
 def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
@@ -82,7 +82,7 @@ def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
 freqs_chunk = freqs_i[start:end]
 x_chunk_complex = torch.view_as_complex(x_chunk.to(torch.float32).reshape(end - start, n, -1, 2))
-x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(torch.bfloat16)
+x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(GET_DTYPE())
 output_chunks.append(x_chunk_embedded)
 del x_chunk_complex, x_chunk_embedded
 torch.cuda.empty_cache()
@@ -101,7 +101,7 @@ def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
 del result
 torch.cuda.empty_cache()
-return x_i.to(torch.bfloat16)
+return x_i.to(GET_DTYPE())
 def rope_params(max_seq_len, dim, theta=10000):
@@ -123,8 +123,7 @@ def sinusoidal_embedding_1d(dim, position):
 # calculation
 sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
 x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
-if GET_DTYPE() == "BF16":
-x = x.to(torch.bfloat16)
+x = x.to(GET_SENSITIVE_DTYPE())
 return x
@@ -140,15 +139,15 @@ def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_
 """
 assert len(w.shape) == 1
 cfg_min, cfg_max = cfg_range
-# w = torch.round(w)
-# w = torch.clamp(w, min=cfg_min, max=cfg_max)
+w = torch.round(w)
+w = torch.clamp(w, min=cfg_min, max=cfg_max)
 w = (w - cfg_min) / (cfg_max - cfg_min) # [0, 1]
 w = w * target_range
 half_dim = embedding_dim // 2
 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
 emb = torch.exp(torch.arange(half_dim, dtype=dtype).to(w.device) * -emb).to(w.device)
 emb = w.to(dtype)[:, None] * emb[None, :]
-emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1)
+emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
 if embedding_dim % 2 == 1: # zero pad
 emb = torch.nn.functional.pad(emb, (0, 1).to(w.device))
 assert emb.shape == (w.shape[0], embedding_dim)
...
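The recurring pattern in the Wan and Hunyuan changes above: keep the bulk of the network in the unified dtype (bf16 or fp16), upcast numerically sensitive operations (norms, modulation, time/text embeddings) to the sensitive dtype, and cast back before the next matmul. Below is a minimal, self-contained illustration of that pattern, with dtypes passed explicitly instead of read from the environment; it is not the project's code, only a sketch of the idea.

```python
import torch


def modulate_mixed_precision(x, shift, scale, infer_dtype=torch.float16, sensitive_dtype=torch.float32):
    # Normalize and modulate in the sensitive dtype, then return to the inference
    # dtype, mirroring the norm1/norm2 handling in WanTransformerInfer above.
    out = torch.nn.functional.layer_norm(x.to(sensitive_dtype), (x.shape[-1],))
    out = out * (1 + scale.to(sensitive_dtype)) + shift.to(sensitive_dtype)
    return out.to(infer_dtype)


x = torch.randn(4, 64, dtype=torch.float16)
shift = torch.zeros(64)
scale = torch.zeros(64)
print(modulate_mixed_precision(x, shift, scale).dtype)  # torch.float16
```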