Commit 091a2a85 authored by sandy, committed by GitHub

[Feat] Add torch compile support for Sekotalk (#294)

parent accbf710
@@ -19,30 +19,79 @@ def linear_interpolation(features, output_len: int):
     return output_features.transpose(1, 2)


-def get_q_lens_audio_range(
-    batchsize: int,
-    n_tokens_per_rank: int,
-    n_query_tokens: int,
-    n_tokens_per_frame: int,
-    sp_rank: int,
+@torch.compiler.disable
+def get_max_int(q_lens, k_lens):
+    max_seqlen_q = int(q_lens.max().item())
+    max_seqlen_k = int(k_lens.max().item())
+    return max_seqlen_q, max_seqlen_k
+
+
+def get_qk_lens_audio_range(
+    n_tokens_per_rank: torch.Tensor,
+    n_query_tokens: torch.Tensor,
+    n_tokens_per_frame: torch.Tensor,
+    sp_rank: torch.Tensor,
+    num_tokens_x4,
 ):
+    device = n_tokens_per_rank.device
+    dtype = torch.int32
     if n_query_tokens == 0:
-        q_lens = [1] * batchsize
-        return q_lens, 0, 1
+        q_lens = torch.ones(1, dtype=dtype, device=device)
+        t0 = torch.tensor(0, device=device)
+        t1 = torch.tensor(1, device=device)
+        k_lens = torch.full((t1 - t0,), num_tokens_x4, dtype=dtype, device=device)
+        max_seqlen_q, max_seqlen_k = get_max_int(q_lens, k_lens)
+        return q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1
     idx0 = n_tokens_per_rank * sp_rank
     first_length = n_tokens_per_frame - idx0 % n_tokens_per_frame
-    first_length = min(first_length, n_query_tokens)
-    n_frames = (n_query_tokens - first_length) // n_tokens_per_frame
+    first_length = torch.minimum(first_length, n_query_tokens)
+    n_frames = torch.div(n_query_tokens - first_length, n_tokens_per_frame, rounding_mode="floor")
     last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length
-    q_lens = []
-    if first_length > 0:
-        q_lens.append(first_length)
-    q_lens += [n_tokens_per_frame] * n_frames
-    if last_length > 0:
-        q_lens.append(last_length)
+
+    first_tensor = first_length.unsqueeze(0)  # [1]
+    frame_tensor = n_tokens_per_frame.repeat(n_frames)  # [n_frames]
+    last_tensor = last_length.unsqueeze(0)  # [1]
+
+    q_lens_all = torch.cat([first_tensor, frame_tensor, last_tensor])
+    q_lens = q_lens_all[q_lens_all > 0].to(dtype)
+
     t0 = idx0 // n_tokens_per_frame
-    t1 = t0 + len(q_lens)
-    return q_lens * batchsize, t0, t1
+    t1 = t0 + q_lens.numel()
+
+    k_lens = torch.full((t1 - t0,), num_tokens_x4, dtype=dtype, device=device)
+    assert q_lens.shape == k_lens.shape
+    max_seqlen_q, max_seqlen_k = get_max_int(q_lens, k_lens)
+    return q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1
+
+
+def calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank, n_tokens):
+    tail_length = n_tokens_per_rank * sp_size - n_tokens
+    n_unused_ranks = tail_length // n_tokens_per_rank
+    if sp_rank > sp_size - n_unused_ranks - 1:
+        n_query_tokens = 0
+    elif sp_rank == sp_size - n_unused_ranks - 1:
+        val = n_tokens_per_rank - (tail_length % n_tokens_per_rank)
+        n_query_tokens = val
+    else:
+        n_query_tokens = n_tokens_per_rank
+    if n_query_tokens > 0:
+        hidden_states_aligned = hidden_states[:, :n_query_tokens]
+        hidden_states_tail = hidden_states[:, n_query_tokens:]
+    else:
+        # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
+        hidden_states_aligned = hidden_states[:, :1]
+        hidden_states_tail = hidden_states[:, 1:]
+    return n_query_tokens, hidden_states_aligned, hidden_states_tail


 class PerceiverAttentionCA(nn.Module):
@@ -73,7 +122,7 @@ class PerceiverAttentionCA(nn.Module):
         shift_scale_gate[:, 2] = 1
         self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

-    def forward(self, x, latents, t_emb, q_lens, k_lens):
+    def forward(self, x, latents, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k):
         """x shape (batchsize, latent_frame, audio_tokens_per_latent,
         model_dim) latents (batchsize, length, model_dim)"""
         batchsize = len(x)
@@ -90,14 +139,15 @@ class PerceiverAttentionCA(nn.Module):
         q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
         k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
         v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads)
         out = flash_attn.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v,
             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
-            max_seqlen_q=q_lens.max().item(),
-            max_seqlen_k=k_lens.max().item(),
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
             dropout_p=0.0,
             softmax_scale=None,
             causal=False,
......
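A note on the pattern above (illustrative sketch only, not part of the commit; names and shapes are made up). The max sequence lengths passed to flash_attn_varlen_func must be plain Python ints, which requires .item() and a host-device sync; wrapping that step in a torch.compiler.disable'd helper keeps the scalar extraction out of the traced region while the rest of the length bookkeeping stays tensor-only. A minimal sketch of the same idea, assuming PyTorch 2.x:

import torch

@torch.compiler.disable
def to_python_ints(q_lens: torch.Tensor, k_lens: torch.Tensor):
    # .item() synchronizes with the device; running it in a disabled helper
    # keeps the scalar extraction outside the compiled graph.
    return int(q_lens.max().item()), int(k_lens.max().item())

@torch.compile
def attention_lengths(q_lens: torch.Tensor, k_lens: torch.Tensor):
    # Cumulative sequence lengths in the form varlen attention kernels expect.
    cu_q = torch.cat([q_lens.new_zeros(1), q_lens]).cumsum(0, dtype=torch.int32)
    cu_k = torch.cat([k_lens.new_zeros(1), k_lens]).cumsum(0, dtype=torch.int32)
    max_q, max_k = to_python_ints(q_lens, k_lens)  # excluded from tracing
    return cu_q, cu_k, max_q, max_k

if __name__ == "__main__":
    q = torch.tensor([3, 5, 2], dtype=torch.int32)
    k = torch.tensor([7, 7, 7], dtype=torch.int32)
    print(attention_lengths(q, k))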
@@ -11,8 +11,11 @@ class WanAudioPostInfer(WanPostInfer):
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, x, pre_infer_out):
         x = x[: pre_infer_out.seq_lens[0]]
-        pre_infer_out.grid_sizes[:, 0] -= 1
-        x = self.unpatchify(x, pre_infer_out.grid_sizes)
+
+        t, h, w = pre_infer_out.grid_sizes.tuple
+        grid_sizes = (t - 1, h, w)
+        x = self.unpatchify(x, grid_sizes)
+
         if self.clean_cuda_cache:
             torch.cuda.empty_cache()
......
@@ -3,7 +3,7 @@ import torch
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *

-from ..module_io import WanPreInferModuleOutput
+from ..module_io import GridOutput, WanPreInferModuleOutput
 from ..utils import rope_params, sinusoidal_embedding_1d
@@ -61,9 +61,9 @@ class WanAudioPreInfer(WanPreInfer):
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
-        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
+        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
+        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)

         y = weights.patch_embedding.apply(y.unsqueeze(0))
         y = y.flatten(2).transpose(1, 2).contiguous()
@@ -114,6 +114,7 @@ class WanAudioPreInfer(WanPreInfer):
             del context_clip
             torch.cuda.empty_cache()

+        grid_sizes = GridOutput(tensor=grid_sizes, tuple=(grid_sizes[0][0].item(), grid_sizes[0][1].item(), grid_sizes[0][2].item()))
         return WanPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
......
 import torch
 import torch.distributed as dist

-from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import get_q_lens_audio_range
+from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import calculate_n_query_tokens, get_qk_lens_audio_range
 from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
@@ -18,11 +18,9 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)

-        audio_grid_sizes = [row.clone() for row in pre_infer_out.grid_sizes]
-        audio_grid_sizes[0][0] -= 1
         x = self.modify_hidden_states(
             hidden_states=x.to(self.infer_dtype),
-            grid_sizes=audio_grid_sizes,
+            grid_sizes=pre_infer_out.grid_sizes.tensor,
             ca_block=self.audio_adapter.ca[self.block_idx],
             audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
             t_emb=self.scheduler.audio_adapter_t_emb,
@@ -41,11 +39,14 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
         """
         if len(hidden_states.shape) == 2:  # add the batch dim
             hidden_states = hidden_states.unsqueeze(0)  # bs = 1
-        t, h, w = grid_sizes[0].tolist()
-        n_tokens = t * h * w
+
+        total_tokens = grid_sizes[0].prod()
+        pre_frame_tokens = grid_sizes[0][1:].prod()
+        n_tokens = total_tokens - pre_frame_tokens  # drop the ref-image tokens

         ori_dtype = hidden_states.dtype
         device = hidden_states.device
-        bs, n_tokens_per_rank = hidden_states.shape[:2]
+        n_tokens_per_rank = torch.tensor(hidden_states.size(1), dtype=torch.int32, device=device)

         if seq_p_group is not None:
             sp_size = dist.get_world_size(seq_p_group)
@@ -54,35 +55,15 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
             sp_size = 1
             sp_rank = 0

-        tail_length = n_tokens_per_rank * sp_size - n_tokens
-        n_unused_ranks = tail_length // n_tokens_per_rank
-        if sp_rank > sp_size - n_unused_ranks - 1:
-            n_query_tokens = 0
-        elif sp_rank == sp_size - n_unused_ranks - 1:
-            n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
-        else:
-            n_query_tokens = n_tokens_per_rank
-        if n_query_tokens > 0:
-            hidden_states_aligned = hidden_states[:, :n_query_tokens]
-            hidden_states_tail = hidden_states[:, n_query_tokens:]
-        else:
-            # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
-            hidden_states_aligned = hidden_states[:, :1]
-            hidden_states_tail = hidden_states[:, 1:]
-        q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
-        q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
-        """
-        processing audio features in sp_state can be moved outside.
-        """
-        audio_encoder_output = audio_encoder_output[:, t0:t1]
-        k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
-        assert q_lens.shape == k_lens.shape
+        n_query_tokens, hidden_states_aligned, hidden_states_tail = calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
+        q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1 = get_qk_lens_audio_range(
+            n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=pre_frame_tokens, sp_rank=sp_rank, num_tokens_x4=self.num_tokens_x4
+        )

         # ca_block: the cross-attention module
         if self.audio_adapter.cpu_offload:
             ca_block.to("cuda")
-        residual = ca_block(audio_encoder_output, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
+        residual = ca_block(audio_encoder_output[:, t0:t1], hidden_states_aligned, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k) * weight
         if self.audio_adapter.cpu_offload:
             ca_block.to("cpu")
         residual = residual.to(ori_dtype)  # the audio cross-attention output is injected back as a residual
......
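To make the per-rank split that calculate_n_query_tokens performs easier to follow, here is a worked toy example with assumed numbers (not taken from the repository). With 10 real tokens split across 4 sequence-parallel ranks of 3 tokens each, the padded length is 12 and the 2-token tail is trimmed from the last rank:

# Toy illustration of the calculate_n_query_tokens branching (plain Python).
n_tokens, sp_size, n_tokens_per_rank = 10, 4, 3
tail_length = n_tokens_per_rank * sp_size - n_tokens   # 2 padded tokens
n_unused_ranks = tail_length // n_tokens_per_rank       # 0 fully padded ranks

for sp_rank in range(sp_size):
    if sp_rank > sp_size - n_unused_ranks - 1:
        n_query = 0                                      # rank holds only padding
    elif sp_rank == sp_size - n_unused_ranks - 1:
        n_query = n_tokens_per_rank - (tail_length % n_tokens_per_rank)  # partially padded rank
    else:
        n_query = n_tokens_per_rank                      # rank is fully used
    print(sp_rank, n_query)  # ranks 0-2 keep 3 query tokens each, rank 3 keeps 1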
@@ -4,10 +4,16 @@ from typing import Any, Dict
 import torch


+@dataclass
+class GridOutput:
+    tensor: torch.Tensor
+    tuple: tuple
+
+
 @dataclass
 class WanPreInferModuleOutput:
     embed: torch.Tensor
-    grid_sizes: torch.Tensor
+    grid_sizes: GridOutput
     x: torch.Tensor
     embed0: torch.Tensor
     seq_lens: torch.Tensor
......
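For orientation, a small self-contained sketch of how a GridOutput pair might be built and consumed (illustrative only; make_grid_output and the latent shape are assumptions, not code from this commit). The .tuple side carries static Python ints that torch.compile can specialize on, while the .tensor side remains available where a device tensor is needed:

from dataclasses import dataclass
import torch

@dataclass
class GridOutput:
    tensor: torch.Tensor
    tuple: tuple

def make_grid_output(x: torch.Tensor) -> GridOutput:
    # x is assumed to be a patch-embedded latent of shape (B, C, F, H, W).
    grid = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
    return GridOutput(tensor=grid, tuple=(int(grid[0, 0]), int(grid[0, 1]), int(grid[0, 2])))

if __name__ == "__main__":
    latent = torch.zeros(1, 16, 21, 30, 52)
    g = make_grid_output(latent)
    print(g.tuple)         # (21, 30, 52) -- plain ints, compile-friendly
    print(g.tensor.shape)  # torch.Size([1, 3]) -- device tensor form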
...@@ -188,7 +188,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer): ...@@ -188,7 +188,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0) ) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
self.phase_params["y_out"] = self.infer_self_attn( self.phase_params["y_out"] = self.infer_self_attn(
cur_phase, cur_phase,
pre_infer_out.grid_sizes, pre_infer_out.grid_sizes.tuple,
x, x,
pre_infer_out.seq_lens, pre_infer_out.seq_lens,
pre_infer_out.freqs, pre_infer_out.freqs,
......
@@ -15,7 +15,7 @@ class WanPostInfer:
         self.scheduler = scheduler

     def infer(self, x, pre_infer_out):
-        x = self.unpatchify(x, pre_infer_out.grid_sizes)
+        x = self.unpatchify(x, pre_infer_out.grid_sizes.tuple)
         if self.clean_cuda_cache:
             torch.cuda.empty_cache()
@@ -23,12 +23,8 @@ class WanPostInfer:
         return [u.float() for u in x]

     def unpatchify(self, x, grid_sizes):
-        x = x.unsqueeze(0)
         c = self.out_dim
-        out = []
-        for u, v in zip(x, grid_sizes.tolist()):
-            u = u[: math.prod(v)].view(*v, *self.patch_size, c)
-            u = torch.einsum("fhwpqrc->cfphqwr", u)
-            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
-            out.append(u)
-        return out
+        x = x[: math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c)
+        x = torch.einsum("fhwpqrc->cfphqwr", x)
+        x = x.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
+        return [x]
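The rewritten unpatchify above drops the per-batch Python loop and works on one token sequence with a static (f, h, w) tuple. A standalone sketch with made-up sizes, just to show the view/einsum/reshape layout (grid (f, h, w), patch (p, q, r), c channels); the numbers and helper name are illustrative:

import math
import torch

def unpatchify(x: torch.Tensor, grid_sizes: tuple, patch_size=(1, 2, 2), out_dim=16):
    # x: (seq_len, p*q*r*c) token sequence; grid_sizes: static (f, h, w) ints.
    c = out_dim
    x = x[: math.prod(grid_sizes)].view(*grid_sizes, *patch_size, c)
    x = torch.einsum("fhwpqrc->cfphqwr", x)           # interleave grid and patch axes
    return x.reshape(c, *[g * p for g, p in zip(grid_sizes, patch_size)])

if __name__ == "__main__":
    grid, patch, c = (5, 6, 8), (1, 2, 2), 16
    tokens = torch.randn(math.prod(grid), math.prod(patch) * c)
    video = unpatchify(tokens, grid, patch, c)
    print(video.shape)  # torch.Size([16, 5, 12, 16]) == (c, f*p, h*q, w*r)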
@@ -2,7 +2,7 @@ import torch
 from lightx2v.utils.envs import *

-from .module_io import WanPreInferModuleOutput
+from .module_io import GridOutput, WanPreInferModuleOutput
 from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
@@ -61,13 +61,13 @@ class WanPreInfer:
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
-        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
+        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
+        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)

         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
-            s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
+            s = torch.tensor([self.cfg_scale], dtype=torch.float32, device=x.device)
             cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32).type_as(x)
             cfg_embed = weights.cfg_cond_proj_1.apply(cfg_embed)
             cfg_embed = torch.nn.functional.silu(cfg_embed)
@@ -117,6 +117,7 @@ class WanPreInfer:
             del context_clip
             torch.cuda.empty_cache()

+        grid_sizes = GridOutput(tensor=grid_sizes, tuple=(grid_sizes[0][0].item(), grid_sizes[0][1].item(), grid_sizes[0][2].item()))
         return WanPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
......
@@ -96,7 +96,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             )
             y_out = self.infer_self_attn(
                 block.compute_phases[0],
-                pre_infer_out.grid_sizes,
+                pre_infer_out.grid_sizes.tuple,
                 x,
                 pre_infer_out.seq_lens,
                 pre_infer_out.freqs,
......
@@ -6,7 +6,7 @@ from lightx2v.utils.envs import *
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0]
+    f, h, w = grid_sizes
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
@@ -24,7 +24,7 @@ def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
     world_size = dist.get_world_size(seq_p_group)
     cur_rank = dist.get_rank(seq_p_group)
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0]
+    f, h, w = grid_sizes
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
@@ -43,7 +43,7 @@ def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0]
+    f, h, w = grid_sizes
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
......
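The change above means these helpers now receive the plain (f, h, w) tuple rather than a row of the grid tensor. A quick illustration of the difference (not repository code; the numbers are arbitrary): unpacking a tensor row yields 0-d tensors, while the tuple form yields Python ints that can drive Python-level shape arithmetic under torch.compile:

import torch

grid_tensor = torch.tensor([[21, 30, 52]], dtype=torch.int32)
f, h, w = grid_tensor[0]   # old style: three 0-d tensors
print(type(f), f * h * w)  # <class 'torch.Tensor'> tensor(32760, dtype=torch.int32)

f, h, w = (21, 30, 52)     # new style: the GridOutput.tuple form
print(type(f), f * h * w)  # <class 'int'> 32760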