from dataclasses import dataclass, field
from typing import Any, Dict

import torch


# --- Module I/O dataclasses ---
@dataclass
class GridOutput:
    """Container pairing a tensor with a plain Python tuple.

    Used as the type of the ``grid_sizes`` field of
    ``WanPreInferModuleOutput``.

    NOTE(review): presumably both fields describe the same grid sizes in two
    representations (tensor form and host-side tuple form) -- this is inferred
    from the names only; confirm against the code that builds instances.
    """

    # Tensor-valued member; shape and dtype are not established in this file.
    tensor: torch.Tensor
    # Tuple-valued member. The field name shadows the builtin ``tuple`` but is
    # part of the public field interface, so it is left unchanged.
    tuple: tuple


# Pre-infer module output; its grid_sizes field is a GridOutput.
@dataclass
class WanPreInferModuleOutput:
    """Bundle of values carried in a single pre-infer module output object.

    All fields except ``adapter_output`` are required and must be supplied in
    the declared order when constructing positionally.

    NOTE(review): the meaning, shape and dtype of the individual tensors are
    not established in this file; the class name suggests these are the
    results of the Wan model's pre-inference step -- confirm with the code
    that produces and consumes this object.
    """

    embed: torch.Tensor
    # Grid sizes wrapped in a GridOutput (a tensor plus a tuple).
    grid_sizes: GridOutput
    x: torch.Tensor
    embed0: torch.Tensor
    seq_lens: torch.Tensor
    freqs: torch.Tensor
    context: torch.Tensor
    # Optional extras keyed by string. default_factory gives every instance
    # its own fresh dict, so instances never share a mutable default.
    adapter_output: Dict[str, Any] = field(default_factory=dict)