Unverified Commit f21da849 authored by Yang Yong (雍洋), committed by GitHub
parent 3efc43f5
@@ -5,14 +5,15 @@ from lightx2v.utils.envs import *
 class SekoAudioEncoderModel:
-    def __init__(self, model_path, audio_sr, cpu_offload, device):
+    def __init__(self, model_path, audio_sr, cpu_offload, run_device):
         self.model_path = model_path
         self.audio_sr = audio_sr
         self.cpu_offload = cpu_offload
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device(device)
+            self.device = torch.device(run_device)
+        self.run_device = run_device
         self.load()

     def load(self):
@@ -26,13 +27,13 @@ class SekoAudioEncoderModel:
         self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")

     def to_cuda(self):
-        self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")
+        self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)

     @torch.no_grad()
     def infer(self, audio_segment):
-        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.device).to(dtype=GET_DTYPE())
+        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
         if self.cpu_offload:
-            self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")
+            self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
         audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
         if self.cpu_offload:
             self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
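For reference, a minimal sketch of the cpu_offload / run_device pattern this hunk generalizes: the module stays resident on CPU when offloading and is moved to run_device only around the forward pass. The class, layer, and sizes below are hypothetical stand-ins, not the actual SekoAudioEncoderModel.

import torch
import torch.nn as nn

class OffloadedEncoderSketch:
    def __init__(self, cpu_offload: bool, run_device: str = "cuda"):
        self.cpu_offload = cpu_offload
        self.run_device = run_device
        # Resident device: CPU when offloading, otherwise the run device.
        self.device = torch.device("cpu" if cpu_offload else run_device)
        self.encoder = nn.Linear(16, 16).to(self.device)

    @torch.no_grad()
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(self.run_device)
        if self.cpu_offload:
            self.encoder = self.encoder.to(self.run_device)  # move weights in
        out = self.encoder(x)
        if self.cpu_offload:
            self.encoder = self.encoder.to("cpu")  # move weights back out
        return out

run_device = "cuda" if torch.cuda.is_available() else "cpu"
enc = OffloadedEncoderSketch(cpu_offload=True, run_device=run_device)
print(enc.infer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])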
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-# 1. Standard library imports
 import gc
 import math
 import os
 import sys
 from pathlib import Path

-# 2. Third-party library imports
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -747,6 +745,7 @@ class T5EncoderModel:
         text_len,
         dtype=torch.bfloat16,
         device=torch.device("cuda"),
+        run_device=torch.device("cuda"),
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,
@@ -759,6 +758,7 @@ class T5EncoderModel:
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
+        self.run_device = run_device
         if t5_quantized_ckpt is not None and t5_quantized:
             self.checkpoint_path = t5_quantized_ckpt
         else:
@@ -807,8 +807,8 @@ class T5EncoderModel:
     def infer(self, texts):
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.to(self.device)
-        mask = mask.to(self.device)
+        ids = ids.to(self.run_device)
+        mask = mask.to(self.run_device)
         seq_lens = mask.gt(0).sum(dim=1).long()
         with torch.no_grad():
@@ -292,7 +292,7 @@ class VisionTransformer(nn.Module):
         b = x.size(0)

         # embeddings
-        x = self.patch_embedding(x.type(self.patch_embedding.weight.type())).flatten(2).permute(0, 2, 1)
+        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
         if self.pool_type in ("token", "token_fc"):
             x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
         if interpolation:
@@ -426,9 +426,10 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, run_device=torch.device("cuda")):
         self.dtype = dtype
         self.device = device
+        self.run_device = run_device
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
         self.use_31_block = use_31_block
@@ -462,7 +463,7 @@ class CLIPModel:
         return out

     def to_cuda(self):
-        self.model = self.model.cuda()
+        self.model = self.model.to(self.run_device)

     def to_cpu(self):
         self.model = self.model.cpu()
import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input
from loguru import logger
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
flash_attn_varlen_func_v3 = None
logger.info("flash_attn_varlen_func_v3 not available")
if torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
else:
try:
from sageattention import sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
try:
from sageattn3 import sageattn3_blackwell
except ImportError:
logger.info("sageattn3 not found, please install sageattention first")
sageattn3_blackwell = None
def flash_attn_no_pad(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_seqlens,
max_s,
dropout_p,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic,
)
output = rearrange(
pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen),
"b s (h d) -> b s h d",
h=nheads,
)
return output
def flash_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
if flash_attn_varlen_func_v3 is None:
raise ImportError("FlashAttention V3 backend not available")
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads)
output_unpad = flash_attn_varlen_func_v3(
query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_q, softmax_scale=softmax_scale, causal=causal, deterministic=deterministic
)
output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads)
return output
def sage_attn_no_pad_v2(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads)
output_unpad = sageattn(
query_unpad.unsqueeze(0),
key_unpad.unsqueeze(0),
value_unpad.unsqueeze(0),
tensor_layout="NHD",
).squeeze(0)
output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads)
return output
def sage_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads)
output_unpad = sageattn3_blackwell(query_unpad.unsqueeze(0).transpose(1, 2), key_unpad.unsqueeze(0).transpose(1, 2), value_unpad.unsqueeze(0).transpose(1, 2)).transpose(1, 2).squeeze(0)
output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads)
return output
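A hedged usage sketch for the packed-QKV helpers above, assuming flash-attn is installed, a CUDA device is available, and the snippet runs in the same module as (or imports) flash_attn_no_pad. Shapes follow the (B, S, 3, H, D) layout the helper expects, key_padding_mask marks valid (non-padded) positions, and the sizes are arbitrary.

import torch

B, S, H, D = 2, 128, 8, 64
qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.bfloat16)

# Second sample is padded to half length; True marks valid tokens.
key_padding_mask = torch.ones(B, S, dtype=torch.bool, device="cuda")
key_padding_mask[1, S // 2:] = False

out = flash_attn_no_pad(qkv, key_padding_mask, causal=False, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 128, 8, 64]); padded positions are zero-filled by pad_input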
import gc
import json
import numpy as np
import torch
import torch.nn.functional as F
from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer
class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.magcache_thresh = config.get("magcache_thresh", 0.2)
self.K = config.get("magcache_K", 6)
self.retention_ratio = config.get("magcache_retention_ratio", 0.2)
self.mag_ratios = np.array(config.get("magcache_ratios", []))
self.enable_magcache_calibration = config.get("magcache_calibration", True)
# {True: cond_param, False: uncond_param}
self.accumulated_err = {True: 0.0, False: 0.0}
self.accumulated_steps = {True: 0, False: 0}
self.accumulated_ratio = {True: 1.0, False: 1.0}
self.residual_cache = {True: None, False: None}
self.residual_cache_txt = {True: None, False: None}
# calibration args
self.norm_ratio = [[1.0], [1.0]] # mean of magnitude ratio
self.norm_std = [[0.0], [0.0]] # std of magnitude ratio
self.cos_dis = [[0.0], [0.0]] # cosine distance of residual features
@torch.no_grad()
def infer(self, weights, infer_module_out):
skip_forward = False
step_index = self.scheduler.step_index
infer_condition = self.scheduler.infer_condition
if self.enable_magcache_calibration:
skip_forward = False
else:
if step_index >= int(self.config["infer_steps"] * self.retention_ratio):
# conditional and unconditional in one list
cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index]
# magnitude ratio between current step and the cached step
self.accumulated_ratio[infer_condition] = self.accumulated_ratio[infer_condition] * cur_mag_ratio
self.accumulated_steps[infer_condition] += 1 # skip steps plus 1
# skip error of current steps
cur_skip_err = np.abs(1 - self.accumulated_ratio[infer_condition])
# accumulated error of multiple steps
self.accumulated_err[infer_condition] += cur_skip_err
if self.accumulated_err[infer_condition] < self.magcache_thresh and self.accumulated_steps[infer_condition] <= self.K:
skip_forward = True
else:
self.accumulated_err[infer_condition] = 0
self.accumulated_steps[infer_condition] = 0
self.accumulated_ratio[infer_condition] = 1.0
if not skip_forward:
self.infer_calculating(weights, infer_module_out)
else:
self.infer_using_cache(infer_module_out)
x = self.infer_final_layer(weights, infer_module_out)
return x
def infer_calculating(self, weights, infer_module_out):
step_index = self.scheduler.step_index
infer_condition = self.scheduler.infer_condition
ori_img = infer_module_out.img.clone()
ori_txt = infer_module_out.txt.clone()
self.infer_func(weights, infer_module_out)
previous_residual = infer_module_out.img - ori_img
previous_residual_txt = infer_module_out.txt - ori_txt
if self.config["cpu_offload"]:
previous_residual = previous_residual.cpu()
previous_residual_txt = previous_residual_txt.cpu()
if self.enable_magcache_calibration and step_index >= 1:
norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item()
norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item()
cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item()
_index = int(not infer_condition)
self.norm_ratio[_index].append(round(norm_ratio, 5))
self.norm_std[_index].append(round(norm_std, 5))
self.cos_dis[_index].append(round(cos_dis, 5))
print(f"time: {step_index}, infer_condition: {infer_condition}, norm_ratio: {norm_ratio}, norm_std: {norm_std}, cos_dis: {cos_dis}")
self.residual_cache[infer_condition] = previous_residual
self.residual_cache_txt[infer_condition] = previous_residual_txt
if self.config["cpu_offload"]:
ori_img = ori_img.to("cpu")
ori_txt = ori_txt.to("cpu")
del ori_img, ori_txt
torch.cuda.empty_cache()
gc.collect()
def infer_using_cache(self, infer_module_out):
residual_img = self.residual_cache[self.scheduler.infer_condition]
residual_txt = self.residual_cache_txt[self.scheduler.infer_condition]
infer_module_out.img.add_(residual_img.cuda())
infer_module_out.txt.add_(residual_txt.cuda())
def clear(self):
self.accumulated_err = {True: 0.0, False: 0.0}
self.accumulated_steps = {True: 0, False: 0}
self.accumulated_ratio = {True: 1.0, False: 1.0}
self.residual_cache = {True: None, False: None}
self.residual_cache_txt = {True: None, False: None}
if self.enable_magcache_calibration:
print("norm ratio")
print(self.norm_ratio)
print("norm std")
print(self.norm_std)
print("cos_dis")
print(self.cos_dis)
def save_json(filename, obj_list):
with open(filename + ".json", "w") as f:
json.dump(obj_list, f)
save_json("mag_ratio", self.norm_ratio)
save_json("mag_std", self.norm_std)
save_json("cos_dis", self.cos_dis)
class HunyuanTransformerInferTeaCaching(HunyuanVideo15OffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.teacache_thresh = self.config["teacache_thresh"]
self.coefficients = self.config["coefficients"]
self.accumulated_rel_l1_distance_odd = 0
self.previous_modulated_input_odd = None
self.previous_residual_odd = None
self.accumulated_rel_l1_distance_even = 0
self.previous_modulated_input_even = None
self.previous_residual_even = None
def calculate_should_calc(self, img, vec, block):
inp = img.clone()
vec_ = vec.clone()
img_mod_layer = block.img_branch.img_mod
if self.config["cpu_offload"]:
img_mod_layer.to_cuda()
img_mod1_shift, img_mod1_scale, _, _, _, _ = img_mod_layer.apply(vec_).chunk(6, dim=-1)
inp = inp.squeeze(0)
normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6)
modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
del normed_inp, inp, vec_
if self.scheduler.step_index == 0 or self.scheduler.step_index == self.scheduler.infer_steps - 1:
should_calc = True
if self.scheduler.infer_condition:
self.accumulated_rel_l1_distance_odd = 0
self.previous_modulated_input_odd = modulated_inp
else:
self.accumulated_rel_l1_distance_even = 0
self.previous_modulated_input_even = modulated_inp
else:
rescale_func = np.poly1d(self.coefficients)
if self.scheduler.infer_condition:
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_modulated_input_odd).abs().mean() / self.previous_modulated_input_odd.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_modulated_input_odd = modulated_inp
else:
self.accumulated_rel_l1_distance_even += rescale_func(
((modulated_inp - self.previous_modulated_input_even).abs().mean() / self.previous_modulated_input_even.abs().mean()).cpu().item()
)
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance_even = 0
self.previous_modulated_input_even = modulated_inp
del modulated_inp
return should_calc
def infer(self, weights, infer_module_out):
should_calc = self.calculate_should_calc(infer_module_out.img, infer_module_out.vec, weights.double_blocks[0])
if not should_calc:
if self.scheduler.infer_condition:
infer_module_out.img += self.previous_residual_odd
else:
infer_module_out.img += self.previous_residual_even
else:
ori_img = infer_module_out.img.clone()
self.infer_func(weights, infer_module_out)
if self.scheduler.infer_condition:
self.previous_residual_odd = infer_module_out.img - ori_img
else:
self.previous_residual_even = infer_module_out.img - ori_img
x = self.infer_final_layer(weights, infer_module_out)
return x
def clear(self):
if self.previous_residual_odd is not None:
self.previous_residual_odd = self.previous_residual_odd.cpu()
if self.previous_modulated_input_odd is not None:
self.previous_modulated_input_odd = self.previous_modulated_input_odd.cpu()
if self.previous_residual_even is not None:
self.previous_residual_even = self.previous_residual_even.cpu()
if self.previous_modulated_input_even is not None:
self.previous_modulated_input_even = self.previous_modulated_input_even.cpu()
self.previous_modulated_input_odd = None
self.previous_residual_odd = None
self.previous_modulated_input_even = None
self.previous_residual_even = None
torch.cuda.empty_cache()
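A small self-contained sketch of the MagCache skip decision used in infer above: per-step magnitude ratios are multiplied into an accumulated ratio, its deviation from 1 accumulates as error, and a step is skipped while that error stays below magcache_thresh and at most K consecutive steps have been skipped. The ratio values below are made up for illustration; the real ones come from calibration.

import numpy as np

magcache_thresh, K = 0.2, 6
mag_ratios = np.array([0.99, 0.98, 0.97, 0.9, 0.7, 0.99])  # hypothetical calibration output

acc_ratio, acc_err, acc_steps = 1.0, 0.0, 0
for step, ratio in enumerate(mag_ratios):
    acc_ratio *= ratio                 # magnitude ratio vs. the last computed step
    acc_steps += 1
    acc_err += abs(1 - acc_ratio)      # accumulated skip error
    if acc_err < magcache_thresh and acc_steps <= K:
        print(f"step {step}: skip forward pass, reuse cached residual")
    else:
        print(f"step {step}: recompute and reset accumulators")
        acc_ratio, acc_err, acc_steps = 1.0, 0.0, 0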
from dataclasses import dataclass
import torch
@dataclass
class HunyuanVideo15InferModuleOutput:
img: torch.Tensor
txt: torch.Tensor
vec: torch.Tensor
grid_sizes: tuple
@dataclass
class HunyuanVideo15ImgBranchOutput:
img_mod1_gate: torch.Tensor
img_mod2_shift: torch.Tensor
img_mod2_scale: torch.Tensor
img_mod2_gate: torch.Tensor
@dataclass
class HunyuanVideo15TxtBranchOutput:
txt_mod1_gate: torch.Tensor
txt_mod2_shift: torch.Tensor
txt_mod2_scale: torch.Tensor
txt_mod2_gate: torch.Tensor
import torch
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.hunyuan_video.infer.transformer_infer import HunyuanVideo15TransformerInfer
class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer):
def __init__(self, config):
super().__init__(config)
if self.config.get("cpu_offload", False):
offload_granularity = self.config.get("offload_granularity", "block")
if offload_granularity == "block":
self.infer_func = self.infer_with_blocks_offload
elif offload_granularity == "model":
self.infer_func = self.infer_without_offload
else:
raise NotImplementedError
if offload_granularity != "model":
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
@torch.no_grad()
def infer_with_blocks_offload(self, weights, infer_module_out):
for block_idx in range(self.double_blocks_num):
self.block_idx = block_idx
if block_idx == 0:
self.offload_manager.init_first_buffer(weights.double_blocks)
if block_idx < self.double_blocks_num - 1:
self.offload_manager.prefetch_weights(block_idx + 1, weights.double_blocks)
with torch.cuda.stream(self.offload_manager.compute_stream):
infer_module_out.img, infer_module_out.txt = self.infer_double_block(self.offload_manager.cuda_buffers[0], infer_module_out)
self.offload_manager.swap_blocks()
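For intuition, a hedged sketch of the double-buffered prefetch pattern behind infer_with_blocks_offload: while block i runs on the default stream, the weights of block i+1 are copied to the GPU on a side stream. The real WeightAsyncStreamManager copies into reusable GPU buffers and evicts blocks back to CPU; this simplified version just moves the modules and only illustrates the stream/event ordering.

import torch
import torch.nn as nn

def run_blocks_with_offload(blocks_cpu, x):
    load_stream = torch.cuda.Stream()

    def prefetch(block):
        # Enqueue the host-to-device copy of one block on the side stream.
        with torch.cuda.stream(load_stream):
            gpu_block = block.to("cuda", non_blocking=True)
            ready = torch.cuda.Event()
            ready.record(load_stream)
        return gpu_block, ready

    cur = prefetch(blocks_cpu[0])
    for i in range(len(blocks_cpu)):
        nxt = prefetch(blocks_cpu[i + 1]) if i + 1 < len(blocks_cpu) else None
        gpu_block, ready = cur
        torch.cuda.current_stream().wait_event(ready)  # block i's weights are resident
        x = gpu_block(x)                               # compute block i
        cur = nxt
    return x

if torch.cuda.is_available():
    blocks = [nn.Linear(32, 32) for _ in range(4)]
    print(run_blocks_with_offload(blocks, torch.randn(8, 32, device="cuda")).shape)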
import torch
from lightx2v.utils.envs import *
class HunyuanVideo15PostInfer:
def __init__(self, config):
self.config = config
self.unpatchify_channels = config["out_channels"]
self.patch_size = config["patch_size"] # (1, 1, 1)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.no_grad()
def infer(self, x, pre_infer_out):
x = self.unpatchify(x, pre_infer_out.grid_sizes[0], pre_infer_out.grid_sizes[1], pre_infer_out.grid_sizes[2])
return x
def unpatchify(self, x, t, h, w):
"""
Unpatchify a tensorized input back to frame format.
Args:
x (Tensor): Input tensor of shape (N, T, patch_size**2 * C)
t (int): Number of time steps
h (int): Height in patch units
w (int): Width in patch units
Returns:
Tensor: Output tensor of shape (N, C, t * pt, h * ph, w * pw)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
x = x[:, : t * h * w] # remove padding from seq parallel
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
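To make the unpatchify docstring concrete, a standalone shape check that mirrors the reshape/einsum above with the (1, 1, 1) patch size this model uses; the grid sizes are arbitrary.

import torch

N, c = 1, 16
t, h, w = 4, 30, 53      # latent grid in patch units (arbitrary)
pt, ph, pw = 1, 1, 1     # patch_size for this model

x = torch.randn(N, t * h * w, pt * ph * pw * c)
x = x.reshape(N, t, h, w, c, pt, ph, pw)
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(N, c, t * pt, h * ph, w * pw)
print(imgs.shape)  # torch.Size([1, 16, 4, 30, 53])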
import math
from typing import Optional
import torch
from einops import rearrange
from lightx2v.utils.envs import *
from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2
from .module_io import HunyuanVideo15InferModuleOutput
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after apply gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
@torch.compiler.disable
def attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, drop_rate: float = 0.0, attn_mask: Optional[torch.Tensor] = None, causal: bool = False, attn_type: str = "flash_attn2"
) -> torch.Tensor:
"""
Compute attention using flash_attn_no_pad.
Args:
q: Query tensor of shape [B, L, H, D]
k: Key tensor of shape [B, L, H, D]
v: Value tensor of shape [B, L, H, D]
drop_rate: Dropout rate for attention weights.
attn_mask: Optional attention mask of shape [B, L].
causal: Whether to apply causal masking.
Returns:
Output tensor after attention of shape [B, L, H*D]
"""
qkv = torch.stack([q, k, v], dim=2)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.bool()
if attn_type == "flash_attn2":
x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
elif attn_type == "flash_attn3":
x = flash_attn_no_pad_v3(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
elif attn_type == "sage_attn2":
x = sage_attn_no_pad_v2(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
class HunyuanVideo15PreInfer:
def __init__(self, config):
self.config = config
self.patch_size = config["patch_size"]
self.heads_num = config["heads_num"]
self.frequency_embedding_size = 256
self.max_period = 10000
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.no_grad()
def infer(self, weights, inputs):
latents = self.scheduler.latents
grid_sizes_t, grid_sizes_h, grid_sizes_w = latents.shape[2:]
timesteps = self.scheduler.timesteps
t = timesteps[self.scheduler.step_index]
if self.scheduler.infer_condition:
txt, text_mask = inputs["text_encoder_output"]["context"][0], inputs["text_encoder_output"]["context"][1]
else:
txt, text_mask = inputs["text_encoder_output"]["context_null"][0], inputs["text_encoder_output"]["context_null"][1]
byt5_txt, byt5_text_mask = inputs["text_encoder_output"]["byt5_features"], inputs["text_encoder_output"]["byt5_masks"]
siglip_output, siglip_mask = inputs["image_encoder_output"]["siglip_output"], inputs["image_encoder_output"]["siglip_mask"]
txt = txt.to(torch.bfloat16)
if self.config["is_sr_running"]:
if t < 1000 * self.scheduler.noise_scale:
condition = self.scheduler.zero_condition
else:
condition = self.scheduler.condition
img = x = latent_model_input = torch.concat([latents, condition], dim=1)
else:
cond_latents_concat = self.scheduler.cond_latents_concat
mask_concat = self.scheduler.mask_concat
img = x = latent_model_input = torch.concat([latents, cond_latents_concat, mask_concat], dim=1)
img = img.to(torch.bfloat16)
t_expand = t.repeat(latent_model_input.shape[0])
guidance_expand = None
img = weights.img_in.apply(img)
img = img.flatten(2).transpose(1, 2)
t_freq = self.timestep_embedding(t_expand, self.frequency_embedding_size, self.max_period).to(torch.bfloat16)
vec = weights.time_in_0.apply(t_freq)
vec = torch.nn.functional.silu(vec)
vec = weights.time_in_2.apply(vec)
if self.config["is_sr_running"]:
use_meanflow = self.config.get("video_super_resolution", {}).get("use_meanflow", False)
if use_meanflow:
if self.scheduler.step_index == len(timesteps) - 1:
timesteps_r = torch.tensor([0.0], device=latent_model_input.device)
else:
timesteps_r = timesteps[self.scheduler.step_index + 1]
timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
else:
timesteps_r = None
if timesteps_r is not None:
t_freq = self.timestep_embedding(timesteps_r, self.frequency_embedding_size, self.max_period).to(torch.bfloat16)
vec_res = weights.time_r_in_0.apply(t_freq)
vec_res = torch.nn.functional.silu(vec_res)
vec_res = weights.time_r_in_2.apply(vec_res)
vec = vec + vec_res
t_freq = self.timestep_embedding(t_expand, self.frequency_embedding_size, self.max_period).to(torch.bfloat16)
timestep_aware_representations = weights.txt_in_t_embedder_0.apply(t_freq)
timestep_aware_representations = torch.nn.functional.silu(timestep_aware_representations)
timestep_aware_representations = weights.txt_in_t_embedder_2.apply(timestep_aware_representations)
mask_float = text_mask.float().unsqueeze(-1)
context_aware_representations = (txt * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = context_aware_representations.to(torch.bfloat16)
context_aware_representations = weights.txt_in_c_embedder_0.apply(context_aware_representations)
context_aware_representations = torch.nn.functional.silu(context_aware_representations)
context_aware_representations = weights.txt_in_c_embedder_2.apply(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
out = weights.txt_in_input_embedder.apply(txt[0].to(torch.bfloat16))
txt = self.run_individual_token_refiner(weights, out, text_mask, c)
# TODO: this computation can be removed
txt = txt.unsqueeze(0)
txt = txt + weights.cond_type_embedding.apply(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=torch.device("cuda")))
txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
txt = txt[:, : text_mask.sum(), :]
return HunyuanVideo15InferModuleOutput(
img=img.contiguous(),
txt=txt.contiguous(),
vec=torch.nn.functional.silu(vec),
grid_sizes=(grid_sizes_t, grid_sizes_h, grid_sizes_w),
)
def run_individual_token_refiner(self, weights, out, mask, c):
mask = mask.clone().bool()
mask[:, 0] = True # Prevent attention weights from becoming NaN
for block in weights.individual_token_refiner: # block num = 2
gate_msa, gate_mlp = self.adaLN_modulation(block, c)
norm_x = block.norm1.apply(out.unsqueeze(0)).squeeze(0)
qkv = block.self_attn_qkv.apply(norm_x).unsqueeze(0)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
attn = attention(q, k, v, attn_mask=mask, attn_type="flash_attn2").squeeze(0)
out = out + apply_gate(block.self_attn_proj.apply(attn).unsqueeze(0), gate_msa).squeeze(0)
tmp = block.mlp_fc1.apply(block.norm2.apply(out))
tmp = torch.nn.functional.silu(tmp)
tmp = block.mlp_fc2.apply(tmp)
out = out + apply_gate(tmp.unsqueeze(0), gate_mlp).squeeze(0)
return out
def adaLN_modulation(self, weights, c):
c = torch.nn.functional.silu(c)
gate_msa, gate_mlp = weights.adaLN_modulation.apply(c).chunk(2, dim=1)
return gate_msa, gate_mlp
def timestep_embedding(self, t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int): the dimension of the output.
max_period (int): controls the minimum frequency of the embeddings.
Returns:
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def reorder_txt_token(self, byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=False, is_reorder=True):
if is_reorder:
reorder_txt = []
reorder_mask = []
for i in range(text_mask.shape[0]):
byt5_text_mask_i = byt5_text_mask[i].bool()
text_mask_i = text_mask[i].bool()
byt5_txt_i = byt5_txt[i]
txt_i = txt[i]
if zero_feat:
# When using block mask with approximate computation, set pad to zero to reduce error
pad_byt5 = torch.zeros_like(byt5_txt_i[~byt5_text_mask_i])
pad_text = torch.zeros_like(txt_i[~text_mask_i])
reorder_txt_i = torch.cat([byt5_txt_i[byt5_text_mask_i], txt_i[text_mask_i], pad_byt5, pad_text], dim=0)
else:
reorder_txt_i = torch.cat([byt5_txt_i[byt5_text_mask_i], txt_i[text_mask_i], byt5_txt_i[~byt5_text_mask_i], txt_i[~text_mask_i]], dim=0)
reorder_mask_i = torch.cat([byt5_text_mask_i[byt5_text_mask_i], text_mask_i[text_mask_i], byt5_text_mask_i[~byt5_text_mask_i], text_mask_i[~text_mask_i]], dim=0)
reorder_txt.append(reorder_txt_i)
reorder_mask.append(reorder_mask_i)
reorder_txt = torch.stack(reorder_txt)
reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64)
else:
reorder_txt = torch.concat([byt5_txt, txt], dim=1)
reorder_mask = torch.concat([byt5_text_mask, text_mask], dim=1).to(dtype=torch.int64)
return reorder_txt, reorder_mask
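A tiny standalone sketch of the token reordering performed by reorder_txt_token above (the is_reorder branch without zero_feat): for each sample, valid byT5 tokens come first, then valid text tokens, then the padding from both sources. The tensors are made-up toy values.

import torch

byt5_txt = torch.arange(4).view(1, 4, 1).float()     # byT5 features: tokens 0..3
txt = 10 + torch.arange(3).view(1, 3, 1).float()     # text features: tokens 10..12
byt5_mask = torch.tensor([[1, 1, 0, 0]]).bool()      # first 2 byT5 tokens are valid
text_mask = torch.tensor([[1, 1, 1]]).bool()         # all 3 text tokens are valid

b_i, t_i = byt5_txt[0], txt[0]
bm_i, tm_i = byt5_mask[0], text_mask[0]
reordered = torch.cat([b_i[bm_i], t_i[tm_i], b_i[~bm_i], t_i[~tm_i]], dim=0)
print(reordered.squeeze(-1))  # tensor([ 0.,  1., 10., 11., 12.,  2.,  3.])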
from typing import Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from .module_io import HunyuanVideo15ImgBranchOutput, HunyuanVideo15TxtBranchOutput
from .triton_ops import fuse_scale_shift_kernel
def modulate(x, scale, shift):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after apply gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
def apply_hunyuan_rope_with_flashinfer(
xq: torch.Tensor,
xk: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
B, L, H, D = xq.shape
query = xq.reshape(B * L, H * D).contiguous()
key = xk.reshape(B * L, H * D).contiguous()
positions = torch.arange(B * L, device=xq.device, dtype=torch.long)
apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
head_size=D,
cos_sin_cache=cos_sin_cache,
is_neox=False,
)
xq_out = query.view(B, L, H, D)
xk_out = key.view(B, L, H, D)
return xq_out, xk_out
def apply_hunyuan_rope_with_torch(
xq: torch.Tensor,
xk: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
B, L, H, D = xq.shape
cos = cos_sin_cache[:, : D // 2]
sin = cos_sin_cache[:, D // 2 :]
def _apply_rope(x: torch.Tensor) -> torch.Tensor:
x_flat = x.view(B * L, H, D)
x1 = x_flat[..., ::2]
x2 = x_flat[..., 1::2]
cos_ = cos.unsqueeze(1)
sin_ = sin.unsqueeze(1)
o1 = x1.float() * cos_ - x2.float() * sin_
o2 = x2.float() * cos_ + x1.float() * sin_
out = torch.empty_like(x_flat)
out[..., ::2] = o1
out[..., 1::2] = o2
return out.view(B, L, H, D)
xq_out = _apply_rope(xq)
xk_out = _apply_rope(xk)
return xq_out, xk_out
class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
self.double_blocks_num = config["mm_double_blocks_depth"]
self.heads_num = config["heads_num"]
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
self.infer_func = self.infer_without_offload
if self.config.get("modulate_type", "triton") == "triton":
self.modulate_func = fuse_scale_shift_kernel
else:
self.modulate_func = modulate
if self.config.get("rope_type", "flashinfer") == "flashinfer":
self.apply_rope_func = apply_hunyuan_rope_with_flashinfer
else:
self.apply_rope_func = apply_hunyuan_rope_with_torch
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.scheduler.transformer_infer = self
@torch.no_grad()
def infer(self, weights, infer_module_out):
self.infer_func(weights, infer_module_out)
x = self.infer_final_layer(weights, infer_module_out)
return x
@torch.no_grad()
def infer_without_offload(self, weights, infer_module_out):
for i in range(self.double_blocks_num):
infer_module_out.img, infer_module_out.txt = self.infer_double_block(weights.double_blocks[i], infer_module_out)
@torch.no_grad()
def infer_final_layer(self, weights, infer_module_out):
x = torch.cat((infer_module_out.img, infer_module_out.txt), 1)
img = x[:, : infer_module_out.img.shape[1], ...]
shift, scale = weights.final_layer.adaLN_modulation.apply(infer_module_out.vec).chunk(2, dim=1)
img = self.modulate_func(weights.final_layer.norm_final.apply(img.squeeze(0)), scale=scale, shift=shift).squeeze(0)
img = weights.final_layer.linear.apply(img)
return img.unsqueeze(0)
@torch.no_grad()
def infer_double_block(self, weights, infer_module_out):
img_q, img_k, img_v, img_branch_out = self._infer_img_branch_before_attn(weights, infer_module_out)
txt_q, txt_k, txt_v, txt_branch_out = self._infer_txt_branch_before_attn(weights, infer_module_out)
img_attn, txt_attn = self._infer_attn(weights, img_q, img_k, img_v, txt_q, txt_k, txt_v)
img = self._infer_img_branch_after_attn(weights, img_attn, infer_module_out.img, img_branch_out)
txt = self._infer_txt_branch_after_attn(weights, txt_attn, infer_module_out.txt, txt_branch_out)
return img, txt
@torch.no_grad()
def _infer_img_branch_before_attn(self, weights, infer_module_out):
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = weights.img_branch.img_mod.apply(infer_module_out.vec).chunk(6, dim=-1)
img_modulated = weights.img_branch.img_norm1.apply(infer_module_out.img.squeeze(0))
img_modulated = self.modulate_func(img_modulated, scale=img_mod1_scale, shift=img_mod1_shift).squeeze(0)
img_q = weights.img_branch.img_attn_q.apply(img_modulated)
img_k = weights.img_branch.img_attn_k.apply(img_modulated)
img_v = weights.img_branch.img_attn_v.apply(img_modulated)
img_q = rearrange(img_q, "L (H D) -> L H D", H=self.heads_num)
img_k = rearrange(img_k, "L (H D) -> L H D", H=self.heads_num)
img_v = rearrange(img_v, "L (H D) -> L H D", H=self.heads_num)
img_q = weights.img_branch.img_attn_q_norm.apply(img_q)
img_k = weights.img_branch.img_attn_k_norm.apply(img_k)
img_q, img_k = self.apply_rope_func(img_q.unsqueeze(0), img_k.unsqueeze(0), cos_sin_cache=self.scheduler.cos_sin)
return (
img_q,
img_k,
img_v.unsqueeze(0),
HunyuanVideo15ImgBranchOutput(
img_mod1_gate=img_mod1_gate,
img_mod2_shift=img_mod2_shift,
img_mod2_scale=img_mod2_scale,
img_mod2_gate=img_mod2_gate,
),
)
@torch.no_grad()
def _infer_txt_branch_before_attn(self, weights, infer_module_out):
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = weights.txt_branch.txt_mod.apply(infer_module_out.vec).chunk(6, dim=-1)
txt_modulated = weights.txt_branch.txt_norm1.apply(infer_module_out.txt.squeeze(0))
txt_modulated = self.modulate_func(txt_modulated, scale=txt_mod1_scale, shift=txt_mod1_shift).squeeze(0)
txt_q = weights.txt_branch.txt_attn_q.apply(txt_modulated)
txt_k = weights.txt_branch.txt_attn_k.apply(txt_modulated)
txt_v = weights.txt_branch.txt_attn_v.apply(txt_modulated)
txt_q = rearrange(txt_q, "L (H D) -> L H D", H=self.heads_num)
txt_k = rearrange(txt_k, "L (H D) -> L H D", H=self.heads_num)
txt_v = rearrange(txt_v, "L (H D) -> L H D", H=self.heads_num)
txt_q = weights.txt_branch.txt_attn_q_norm.apply(txt_q).to(txt_v)
txt_k = weights.txt_branch.txt_attn_k_norm.apply(txt_k).to(txt_v)
return (
txt_q.unsqueeze(0),
txt_k.unsqueeze(0),
txt_v.unsqueeze(0),
HunyuanVideo15TxtBranchOutput(
txt_mod1_gate=txt_mod1_gate,
txt_mod2_shift=txt_mod2_shift,
txt_mod2_scale=txt_mod2_scale,
txt_mod2_gate=txt_mod2_gate,
),
)
@torch.no_grad()
def _infer_attn(self, weights, img_q, img_k, img_v, txt_q, txt_k, txt_v):
img_seqlen = img_q.shape[1]
query = torch.cat([img_q, txt_q], dim=1)
key = torch.cat([img_k, txt_k], dim=1)
value = torch.cat([img_v, txt_v], dim=1)
seqlen = query.shape[1]
cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to("cuda", non_blocking=True)
if self.config["seq_parallel"]:
attn_out = weights.self_attention_parallel.apply(
q=query,
k=key,
v=value,
img_qkv_len=img_seqlen,
cu_seqlens_qkv=cu_seqlens_qkv,
attention_module=weights.self_attention,
seq_p_group=self.seq_p_group,
)
else:
attn_out = weights.self_attention.apply(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=seqlen,
max_seqlen_kv=seqlen,
)
img_attn, txt_attn = attn_out[:img_seqlen], attn_out[img_seqlen:]
return img_attn, txt_attn
@torch.no_grad()
def _infer_img_branch_after_attn(self, weights, img_attn, img, img_branch_out):
img = img + apply_gate(weights.img_branch.img_attn_proj.apply(img_attn).unsqueeze(0), gate=img_branch_out.img_mod1_gate)
out = weights.img_branch.img_mlp_fc1.apply(
self.modulate_func(weights.img_branch.img_norm2.apply(img.squeeze(0)), scale=img_branch_out.img_mod2_scale, shift=img_branch_out.img_mod2_shift).squeeze(0)
)
out = weights.img_branch.img_mlp_fc2.apply(F.gelu(out, approximate="tanh"))
img = img + apply_gate(out.unsqueeze(0), gate=img_branch_out.img_mod2_gate)
return img
@torch.no_grad()
def _infer_txt_branch_after_attn(self, weights, txt_attn, txt, txt_branch_out):
txt = txt + apply_gate(weights.txt_branch.txt_attn_proj.apply(txt_attn).unsqueeze(0), gate=txt_branch_out.txt_mod1_gate)
out = weights.txt_branch.txt_mlp_fc1.apply(
self.modulate_func(weights.txt_branch.txt_norm2.apply(txt.squeeze(0)), scale=txt_branch_out.txt_mod2_scale, shift=txt_branch_out.txt_mod2_shift).squeeze(0)
)
out = weights.txt_branch.txt_mlp_fc2.apply(F.gelu(out, approximate="tanh"))
txt = txt + apply_gate(out.unsqueeze(0), gate=txt_branch_out.txt_mod2_gate)
return txt
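A self-contained sketch of the interleaved (non-neox) RoPE rotation applied by apply_hunyuan_rope_with_torch and the flashinfer path above, using a hypothetical base-10000 frequency schedule for cos_sin_cache (the real cache comes from the scheduler); the check at the end confirms the rotation preserves per-head norms.

import torch

B, L, H, D = 1, 8, 2, 16
pos = torch.arange(B * L, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
angles = pos[:, None] * inv_freq[None]                            # (B*L, D//2)
cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=-1)   # (B*L, D): cos half | sin half

xq = torch.randn(B, L, H, D)
cos = cos_sin_cache[:, : D // 2].view(B, L, 1, D // 2)
sin = cos_sin_cache[:, D // 2 :].view(B, L, 1, D // 2)
x1, x2 = xq[..., ::2], xq[..., 1::2]                              # interleaved pairs
xq_rot = torch.empty_like(xq)
xq_rot[..., ::2] = x1 * cos - x2 * sin
xq_rot[..., 1::2] = x2 * cos + x1 * sin
print(torch.allclose(xq_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-5))  # True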
from lightx2v.common.modules.weight_module import WeightModule
class HunyuanVideo15PostWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
@@ -22,8 +22,9 @@ class WanAudioModel(WanModel):
     def __init__(self, model_path, config, device):
         self.config = config
-        super().__init__(model_path, config, device)
         self._load_adapter_ckpt()
+        self.run_device = self.config.get("run_device", "cuda")
+        super().__init__(model_path, config, device)

     def _load_adapter_ckpt(self):
         if self.config.get("adapter_model_path", None) is None:
@@ -50,7 +51,7 @@ class WanAudioModel(WanModel):
         if not adapter_offload:
             if not dist.is_initialized() or not load_from_rank0:
                 for key in self.adapter_weights_dict:
-                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(self.device))
+                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(self.run_device))

     def _init_infer_class(self):
         super()._init_infer_class()
@@ -21,7 +21,7 @@ class WanAudioPreInfer(WanPreInfer):
                 rope_params(1024, 2 * (d // 6)),
             ],
             dim=1,
-        ).to(self.device)
+        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.rope_t_dim = d // 2 - 2 * (d // 6)