state.py 1.11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
TeaCache state management.

This module manages the state for TeaCache hooks across diffusion timesteps.
"""

import torch


class TeaCacheState:
    """
    Mutable bookkeeping object for the TeaCache hook.

    Holds the values the TeaCache algorithm carries from one diffusion
    timestep to the next:

    - a counter,
    - the running (accumulated) relative L1 distance, and
    - the tensors remembered from the previous step: the modulated input,
      the residual, and the encoder residual.

    Use ``reset()`` to put the object back into its freshly-constructed
    state before starting another inference run.
    """

    def __init__(self):
        """Build a state object with every field at its starting value."""
        self._clear()

    def reset(self) -> None:
        """Restore every field to its starting value for a new inference run."""
        self._clear()

    def _clear(self) -> None:
        """Assign the starting value of every tracked field.

        Shared by ``__init__`` and ``reset`` so the two can never disagree
        about what a "fresh" state looks like.
        """
        # Timestep tracking.
        self.cnt: int = 0

        # Caching state: running distance plus tensors kept from the prior step.
        self.accumulated_rel_l1_distance: float = 0.0
        self.previous_modulated_input: torch.Tensor | None = None
        self.previous_residual: torch.Tensor | None = None
        self.previous_residual_encoder: torch.Tensor | None = None