from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch


@dataclass
class WanPreInferModuleOutput:
    """Output bundle of the Wan pre-inference module.

    The first seven fields are required ``torch.Tensor`` values belonging to
    the base Wan model. ``adapter_output`` holds optional extra data for an
    adapter model and defaults to ``None``. It is presumably left unset when
    no adapter is in use; confirm with the producer.

    Field order is significant because dataclass instances may be built
    positionally. Do not insert new required fields before existing ones.
    """

    # wan base model
    # NOTE(review): shapes/dtypes of these tensors are not fixed by this
    # class; check the pre-infer module that populates them.
    embed: torch.Tensor
    grid_sizes: torch.Tensor
    x: torch.Tensor
    embed0: torch.Tensor
    seq_lens: torch.Tensor
    freqs: torch.Tensor
    context: torch.Tensor

    # wan adapter model
    # Annotated Optional because the default is None. A bare Dict annotation
    # with a None default is an implicit Optional, which PEP 484 checkers reject.
    adapter_output: Optional[Dict[str, Any]] = None