# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types
from typing import Any

import numpy as np
import torch
from vllm.config import LoadConfig

from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline


class DataCollectionHook(ModelHook):
    """Hook that records (modulated input, transformer output) pairs per step.

    The recorded pairs feed the polynomial fit used for TeaCache coefficient
    estimation. Tensors are detached and moved to CPU numpy arrays right away
    so the trajectory does not keep GPU activations alive.
    """

    _HOOK_NAME = "teacache_collector"

    def __init__(self, transformer_type: str):
        super().__init__()
        self.transformer_type = transformer_type
        self.extractor_fn = None
        self.current_trajectory: list[tuple[np.ndarray, np.ndarray]] = []

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Resolve the model-specific extractor lazily, when the hook is attached.
        self.extractor_fn = get_extractor(self.transformer_type)
        return module

    def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
        ctx = self.extractor_fn(module, *args, **kwargs)
        # Snapshot the modulated input before the transformer blocks run.
        modulated_snapshot = ctx.modulated_input.detach().cpu().numpy()

        block_outputs = ctx.run_transformer_blocks()
        ctx.hidden_states = block_outputs[0]
        if ctx.encoder_hidden_states is not None and len(block_outputs) > 1:
            ctx.encoder_hidden_states = block_outputs[1]

        output_snapshot = ctx.hidden_states.detach().cpu().numpy()
        self.current_trajectory.append((modulated_snapshot, output_snapshot))

        return ctx.postprocess(ctx.hidden_states)

    def start_collection(self):
        """Reset the recorded trajectory ahead of a fresh denoising run."""
        self.current_trajectory = []

    def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
        """Return a shallow copy of the pairs recorded since start_collection()."""
        return list(self.current_trajectory)


class BagelAdapter:
    """Adapter that knows how to load a Bagel pipeline and hook its transformer."""

    @staticmethod
    def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
        """Build a BagelPipeline from *model_path*, load its weights, move to *device*."""
        od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
        od_config.model_class_name = "BagelPipeline"

        pipeline = BagelPipeline(od_config=od_config)
        loader = DiffusersPipelineLoader(LoadConfig())
        loader.load_weights(pipeline)
        pipeline.to(device)
        return pipeline

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        # The transformer module lives at `pipeline.bagel`; "Bagel" is the
        # extractor key DataCollectionHook passes to get_extractor().
        return pipeline.bagel, "Bagel"

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        """Route Bagel's `_forward_flow` through the hook machinery.

        Bagel's denoising loop calls `_forward_flow`, not `forward`, so we:
        1. alias `forward` to the original `_forward_flow`,
        2. register the hook on the transformer's hook registry
           (NOTE(review): this presumably wraps `transformer.forward` in
           place — confirm against HookRegistry.register_hook), then
        3. point `_forward_flow` at the (now wrapped) `forward`.
        Step order matters: the alias must exist before registration, and
        `_forward_flow` must be rebound after it.
        """
        original_forward_flow = transformer._forward_flow

        def forward_alias(self, *args, **kwargs):
            # `self` is ignored: original_forward_flow is already bound.
            return original_forward_flow(*args, **kwargs)

        transformer.forward = types.MethodType(forward_alias, transformer)
        registry = HookRegistry.get_or_create(transformer)
        registry.register_hook(hook._HOOK_NAME, hook)
        transformer._forward_flow = transformer.forward


class DefaultAdapter:
    """Fallback adapter for standard diffusers-style pipelines."""

    @staticmethod
    def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
        """Loading is not supported for the generic adapter."""
        raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        """Return the pipeline's transformer module and its class name."""
        transformer = pipeline.transformer
        return transformer, type(transformer).__name__

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        """Attach *hook* to *transformer* via the standard hook registry."""
        HookRegistry.get_or_create(transformer).register_hook(hook._HOOK_NAME, hook)


# Registry mapping a user-facing model_type string to its adapter class.
# TeaCacheCoefficientEstimator rejects model types without an entry here.
_MODEL_ADAPTERS: dict[str, type] = {
    "Bagel": BagelAdapter,
}

# Guards against division by zero when the reference tensor is all zeros.
_EPSILON = 1e-6


def calculate_relative_l1(tensor_current: np.ndarray, tensor_next: np.ndarray) -> float:
    """Relative L1 distance between consecutive features (TeaCache paper, Eq. 4)."""
    numerator = np.abs(tensor_next - tensor_current).sum()
    denominator = np.abs(tensor_current).sum() + _EPSILON
    return numerator / denominator


def estimate_teacache_coefficients(
    collected_data: list[list[tuple[np.ndarray, np.ndarray]]], poly_order: int = 4
) -> list[float]:
    """Fit a polynomial mapping input-feature change to output change.

    Args:
        collected_data: One trajectory per prompt; each trajectory is a list
            of (modulated_input, model_output) numpy pairs, one per step.
        poly_order: Order of the polynomial fit (default: 4).

    Returns:
        Polynomial coefficients, highest degree first (np.polyfit convention).

    Raises:
        ValueError: If no consecutive-step pairs are available (no
            trajectories, or every trajectory has fewer than two steps).
    """
    input_diffs: list[float] = []
    output_diffs: list[float] = []

    for sample in collected_data:
        # Consecutive-step diffs (Eq. 4): how much the modulated input and
        # the model output each moved between steps t and t+1.
        for t in range(len(sample) - 1):
            feat_in_curr, feat_out_curr = sample[t]
            feat_in_next, feat_out_next = sample[t + 1]
            input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next))
            output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next))

    # Guard before touching numpy: .min()/.max() on a zero-size array raise
    # an opaque ValueError; fail with an actionable message instead.
    if not input_diffs:
        raise ValueError(
            "No step-to-step pairs available for coefficient estimation; "
            "each trajectory needs at least two denoising steps."
        )

    x = np.array(input_diffs, dtype=np.float64)
    y = np.array(output_diffs, dtype=np.float64)

    print("Data statistics:")
    print(f"  Count: {len(x)}")
    print(f"  Input Diffs (x): min={x.min():.4e}, max={x.max():.4e}, mean={x.mean():.4e}")
    print(f"  Output Diffs (y): min={y.min():.4e}, max={y.max():.4e}, mean={y.mean():.4e}")

    return np.polyfit(x, y, poly_order).tolist()


class TeaCacheCoefficientEstimator:
    """Model-agnostic helper to collect data and estimate TeaCache coefficients.

    Usage: construct (loads the pipeline and installs the data-collection
    hook), call collect_from_prompt() one or more times, then estimate().
    """

    def __init__(
        self,
        model_path: str,
        model_type: str = "Bagel",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Load the pipeline for *model_type* and attach the collection hook.

        Raises:
            ValueError: If *model_type* has no registered adapter.
        """
        if model_type not in _MODEL_ADAPTERS:
            available_types = list(_MODEL_ADAPTERS.keys())
            raise ValueError(
                f"Unsupported model_type: '{model_type}'. "
                f"Available types: {available_types}. "
                f"To add support for a new model, add an entry to _MODEL_ADAPTERS."
            )

        # Direct indexing: the membership check above already rejected
        # unknown types, so the previous `.get(..., DefaultAdapter)`
        # fallback was dead code.
        adapter = _MODEL_ADAPTERS[model_type]
        self.pipeline = adapter.load_pipeline(model_path, device, dtype)
        self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
        self.hook = DataCollectionHook(self.transformer_type)
        self.collected_data: list[list[tuple[np.ndarray, np.ndarray]]] = []
        adapter.install_hook(self.transformer, self.hook)

    def collect_from_prompt(self, prompt: str, **generate_kwargs):
        """Run one generation for *prompt* and record its step trajectory.

        Recognized generate_kwargs: num_inference_steps (default 20) and
        seed (default 42); any others are ignored. Empty trajectories
        (hook never fired) are discarded.
        """
        # Kept as a local import as in the original — presumably to avoid
        # an import cycle at module load; confirm before hoisting.
        from vllm_omni.diffusion.request import OmniDiffusionRequest

        self.hook.start_collection()
        req = OmniDiffusionRequest(
            prompt=prompt,
            num_inference_steps=generate_kwargs.get("num_inference_steps", 20),
            seed=generate_kwargs.get("seed", 42),
        )
        self.pipeline.forward(req)
        trajectory = self.hook.stop_collection()
        if trajectory:
            self.collected_data.append(trajectory)

    def estimate(self, poly_order: int = 4) -> list[float]:
        """Estimate polynomial coefficients from collected data.

        Args:
            poly_order: Order of polynomial fit (default: 4)

        Returns:
            List of polynomial coefficients [a_n, a_{n-1}, ..., a_1, a_0]

        Raises:
            RuntimeError: If no data has been collected
        """
        if not self.collected_data:
            raise RuntimeError(
                "No data collected for coefficient estimation. "
                "Call collect_from_prompt() at least once before calling estimate()."
            )
        return estimate_teacache_coefficients(self.collected_data, poly_order)