# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/backends/abstract.py

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, Optional, Protocol, Set,
                    Type, TypeVar)

if TYPE_CHECKING:
    from fastvideo.v1.fastvideo_args import FastVideoArgs
    from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends."""
    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    # @staticmethod
    # @abstractmethod
    # def get_state_cls() -> Type["AttentionState"]:
    #     raise NotImplementedError

    # @classmethod
    # def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
    #     return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Current step of diffusion process
    current_timestep: int

    # @property
    # @abstractmethod
    # def inference_metadata(self) -> Optional["AttentionMetadata"]:
    #     """Return the attention metadata that's required to run prefill
    #     attention."""
    #     pass

    # @property
    # @abstractmethod
    # def training_metadata(self) -> Optional["AttentionMetadata"]:
    #     """Return the attention metadata that's required to run decode
    #     attention."""
    #     pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }

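# Illustrative sketch (hypothetical, not part of the adapted vLLM API): a
# minimal metadata subclass showing that ``asdict_zerocopy`` returns the
# stored tensors by reference instead of deep-copying them.
@dataclass
class _ExampleAttentionMetadata(AttentionMetadata):
    attn_mask: Optional[torch.Tensor] = None


def _asdict_zerocopy_example() -> None:
    # Not executed on import; call manually to observe the zero-copy behavior.
    meta = _ExampleAttentionMetadata(current_timestep=0,
                                     attn_mask=torch.zeros(1, 16, 16))
    as_dict = meta.asdict_zerocopy(skip_fields={"current_timestep"})
    # The dict holds the very same tensor object, so nothing was copied.
    assert as_dict["attn_mask"] is meta.attn_mask
    assert "current_timestep" not in as_dict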

T = TypeVar("T", bound=AttentionMetadata)

# class AttentionState(ABC, Generic[T]):
#     """Holds attention backend-specific objects reused during the
#     lifetime of the model runner."""

#     @abstractmethod
#     def __init__(self, runner: "ModelRunnerBase"):
#         ...

#     @abstractmethod
#     @contextmanager
#     def graph_capture(self, max_batch_size: int):
#         """Context manager used when capturing CUDA graphs."""
#         yield

#     @abstractmethod
#     def graph_clone(self, batch_size: int) -> "AttentionState[T]":
#         """Clone attention state to save in CUDA graph metadata."""
#         ...

#     @abstractmethod
#     def graph_capture_get_metadata_for_batch(
#             self,
#             batch_size: int,
#             is_encoder_decoder_model: bool = False) -> T:
#         """Get attention metadata for CUDA graph capture of batch_size."""
#         ...

#     @abstractmethod
#     def get_graph_input_buffers(
#             self,
#             attn_metadata: T,
#             is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
#         """Get attention-specific input buffers for CUDA graph capture."""
#         ...

#     @abstractmethod
#     def prepare_graph_input_buffers(
#             self,
#             input_buffers: Dict[str, Any],
#             attn_metadata: T,
#             is_encoder_decoder_model: bool = False) -> None:
#         """In-place modify input buffers dict for CUDA graph replay."""
#         ...

#     @abstractmethod
#     def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
#         """Prepare state for forward pass."""
#         ...


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self) -> None:
        """Create the builder, remember some configuration and parameters."""
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        """Prepare for one batch."""
        raise NotImplementedError

    @abstractmethod
    def build(
        self,
        current_timestep: int,
        forward_batch: "ForwardBatch",
        fastvideo_args: "FastVideoArgs",
    ) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError

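# Illustrative usage sketch (hypothetical, not part of the adapted vLLM API):
# how a model runner is expected to drive a concrete metadata builder. The
# ``builder``, ``forward_batch`` and ``fastvideo_args`` objects are assumed
# to be supplied by the caller.
def _build_metadata_example(
        builder: AttentionMetadataBuilder, current_timestep: int,
        forward_batch: "ForwardBatch",
        fastvideo_args: "FastVideoArgs") -> AttentionMetadata:
    builder.prepare()  # reset per-batch state before building
    return builder.build(current_timestep=current_timestep,
                         forward_batch=forward_batch,
                         fastvideo_args=fastvideo_args)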

class AttentionLayer(Protocol):

    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _k_scale_float: float
    _v_scale_float: float

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...


class AttentionImpl(ABC, Generic[T]):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float,
        causal: bool = False,
        num_kv_heads: Optional[int] = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        raise NotImplementedError

    def preprocess_qkv(self, qkv: torch.Tensor,
                       attn_metadata: T) -> torch.Tensor:
        """Preprocess QKV tensor before performing attention operation.

        Default implementation returns the tensor unchanged.
        Subclasses can override this to implement custom preprocessing
        like reshaping, tiling, scaling, or other transformations.

        Called AFTER all_to_all for distributed attention.

        Args:
            qkv: The query-key-value tensor
            attn_metadata: Metadata for the attention operation
            
        Returns:
            Processed QKV tensor
        """
        return qkv

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: T,
    ) -> torch.Tensor:
        """Postprocess the output tensor after the attention operation.

        Default implementation returns the tensor unchanged.
        Subclasses can override this to implement custom postprocessing
        like untiling, scaling, or other transformations.

        Called BEFORE all_to_all for distributed attention.

        Args:
            output: The output tensor from the attention operation
            attn_metadata: Metadata for the attention operation

        Returns:
            Postprocessed output tensor
        """

        return output

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: T,
    ) -> torch.Tensor:
        raise NotImplementedError
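# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the adapted vLLM API): a
# minimal backend built on torch's scaled_dot_product_attention, showing how
# the abstract pieces above fit together. A real backend would typically also
# override ``preprocess_qkv``/``postprocess_output`` to reshape or scale
# tensors around the distributed all_to_all.
# ---------------------------------------------------------------------------
class _SDPAImplExample(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float,
        causal: bool = False,
        num_kv_heads: Optional[int] = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        self.softmax_scale = softmax_scale
        self.causal = causal

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Assumes a [batch, heads, seq_len, head_size] layout for this sketch.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=self.causal, scale=self.softmax_scale)


class _SDPAMetadataBuilderExample(AttentionMetadataBuilder[AttentionMetadata]):

    def __init__(self) -> None:
        pass

    def prepare(self) -> None:
        pass

    def build(
        self,
        current_timestep: int,
        forward_batch: "ForwardBatch",
        fastvideo_args: "FastVideoArgs",
    ) -> AttentionMetadata:
        return AttentionMetadata(current_timestep=current_timestep)


class _SDPABackendExample(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "SDPA_EXAMPLE"

    @staticmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        return _SDPAImplExample

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        return _SDPAMetadataBuilderExample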