module_io.py 610 Bytes
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import torch


@dataclass
class GridOutput:
    """Record that pairs a ``torch.Tensor`` with a plain Python tuple.

    NOTE(review): judging by the ``grid_sizes`` field that uses this type,
    both members presumably carry the same grid sizes in two
    representations -- confirm against the code that constructs it.
    """
    tensor: torch.Tensor
    # The field is named after the builtin type; the annotation on the right
    # still resolves to the builtin ``tuple``. Element types and length are
    # not specified here. Renaming the field would change the generated
    # constructor keyword and the attribute name.
    tuple: tuple


@dataclass
class WanPreInferModuleOutput:
    """Typed record grouping several tensors and auxiliary values.

    The first five fields have no default and must be supplied when the
    record is constructed; every field after them is optional.

    NOTE(review): the class name suggests this is the output of a Wan
    model's pre-inference stage, and the field meanings are inferred from
    their names only -- confirm shapes and semantics against the producer.
    """
    # Required fields (no defaults).
    embed: torch.Tensor
    grid_sizes: GridOutput  # tensor + tuple pair, see GridOutput
    x: torch.Tensor
    embed0: torch.Tensor
    context: torch.Tensor
    # 3D RoPE / position related
    # Defaults to None when no tensor is supplied.
    cos_sin: Optional[torch.Tensor] = None
    # Both counters default to 0.
    # presumably 0 means "not set" -- TODO confirm with consumers
    valid_token_len: int = 0
    valid_latent_num: int = 0
    # extra
    # default_factory gives every instance its own fresh empty dict, so the
    # default is never shared between instances.
    adapter_args: Dict[str, Any] = field(default_factory=dict)
    conditional_dict: Dict[str, Any] = field(default_factory=dict)