module_io.py 873 Bytes
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""Module I/O definitions for LTX2 infer classes."""

from dataclasses import dataclass

import torch


@dataclass
class TransformerArgs:
    """Transformer arguments matching source TransformerArgs structure."""

    x: torch.Tensor | None = None
    context: torch.Tensor | None = None
    context_mask: torch.Tensor | None = None
    timesteps: torch.Tensor | None = None
    embedded_timestep: torch.Tensor | None = None
    positional_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None
    cross_positional_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None
    cross_scale_shift_timestep: torch.Tensor | None = None
    cross_gate_timestep: torch.Tensor | None = None


@dataclass
class LTX2PreInferModuleOutput:
    """Output from LTX2PreInfer module."""

    video_args: TransformerArgs | None = None
    audio_args: TransformerArgs | None = None