module_io.py 449 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
from dataclasses import dataclass
gushiqiao's avatar
gushiqiao committed
2
from typing import Any, List, Optional
helloyongyang's avatar
helloyongyang committed
3
4
5
6
7
8
9
10
11
12
13
14
15

import torch


@dataclass
class WanPreInferModuleOutput:
    """Bundle of values produced by a Wan "pre-infer" stage.

    The seven tensor fields are required; the trailing fields are optional
    extras with defaults. Field order is part of the positional-construction
    interface and must not change.

    NOTE(review): tensor shapes/dtypes and the exact meaning of each field are
    not established in this module — confirm against the producer/consumers.
    """

    # Required tensors (semantics defined by the code that fills them in).
    embed: torch.Tensor
    grid_sizes: torch.Tensor
    x: torch.Tensor
    embed0: torch.Tensor
    seq_lens: torch.Tensor
    freqs: torch.Tensor
    context: torch.Tensor

    # Optional extras. They are annotated Optional[...] because their default
    # is None; a bare List[Any] with a None default is an implicit-Optional
    # annotation that type checkers reject.
    audio_dit_blocks: Optional[List[Any]] = None  # presumably extra audio DiT blocks — TODO confirm
    valid_patch_length: Optional[int] = None
    hints: Optional[List[Any]] = None
    # Defaults to 1.0; presumably a multiplier applied to `context` — verify at call site.
    context_scale: float = 1.0