packed_seq_params.py 421 Bytes
Newer Older
liangjing's avatar
liangjing committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional

from torch import Tensor


@dataclass
class PackedSeqParams:
    """Parameters for the `thd` (packed) sequence format.

    These are passed to TEDotProductAttention and to the fused RoPE kernels.
    Every field defaults to ``None``, so an instance built with no arguments
    means "no packed-sequence information supplied".
    """

    # Layout string for the q/k/v tensors; presumably "thd" when sequences
    # are packed — confirm against callers.
    qkv_format: Optional[str] = None
    # Presumably cumulative sequence-length offsets for the query side of
    # the packed batch — confirm against callers.
    cu_seqlens_q: Optional[Tensor] = None
    # Presumably the same offsets for the key/value side — confirm.
    cu_seqlens_kv: Optional[Tensor] = None
    # Maximum query sequence length. Annotated as Tensor; NOTE(review):
    # some consumers may expect a Python int here — confirm.
    max_seqlen_q: Optional[Tensor] = None
    # Maximum key/value sequence length (same caveat as max_seqlen_q).
    max_seqlen_kv: Optional[Tensor] = None