packed_seq_params.py 358 Bytes
Newer Older
wangsen's avatar
wangsen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
from dataclasses import dataclass
from typing import Optional

from torch import Tensor


@dataclass
class PackedSeqParams:
    """Parameters for the `thd` (packed) sequence format.

    Bundles the arguments forwarded to TEDotProductAttention and the fused
    RoPE kernels when sequences are packed in the `thd` format. Every field
    is optional and defaults to ``None``, so the annotations say
    ``Optional[...]`` explicitly rather than relying on implicit Optional.

    Attributes:
        qkv_format: Name of the query/key/value layout; the intended value
            for packed sequences appears to be ``"thd"`` (confirm against
            the consumer).
        cu_seqlens_q: Presumably the cumulative sequence-length offsets for
            the query side of the packed batch -- TODO confirm.
        cu_seqlens_kv: Presumably the cumulative sequence-length offsets for
            the key/value side of the packed batch -- TODO confirm.
        max_seqlen_q: Maximum query sequence length in the packed batch.
        max_seqlen_kv: Maximum key/value sequence length in the packed batch.
    """

    # Field order is part of the public constructor signature (positional
    # construction), so it must not be changed.
    qkv_format: Optional[str] = None
    cu_seqlens_q: Optional[Tensor] = None
    cu_seqlens_kv: Optional[Tensor] = None
    # NOTE(review): the max_seqlen_* fields are annotated as Tensor; confirm
    # whether callers actually pass a tensor or a plain Python int.
    max_seqlen_q: Optional[Tensor] = None
    max_seqlen_kv: Optional[Tensor] = None