forward_context.py 1.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from typing import Any, Optional

import torch

# Attribute names under which the helpers in this module stash KV-compression
# state on a forward-context object (written with setattr, read with getattr).
# Prompt payload dict (see get/set_kv_compression_prompt_payload).
_PROMPT_PAYLOAD_ATTR = "_kv_compression_prompt_payload"
# Single compact-slot tensor used for every layer when per_layer_topk is False.
_COMPACT_SLOTS_ATTR = "_kv_compression_compact_slots"
# Dict of layer key -> compact-slot tensor, used when per_layer_topk is True.
_COMPACT_SLOTS_BY_LAYER_ATTR = "_kv_compression_compact_slots_by_layer"


def get_kv_compression_prompt_payload(
    forward_context: Any,
) -> Optional[dict[str, torch.Tensor]]:
    """Fetch the KV-compression prompt payload stashed on ``forward_context``.

    Returns:
        The dict previously attached by ``set_kv_compression_prompt_payload``,
        or ``None`` when no payload has been attached to this context object.
    """
    payload = getattr(forward_context, _PROMPT_PAYLOAD_ATTR, None)
    return payload


def set_kv_compression_prompt_payload(
    forward_context: Any,
    payload: dict[str, torch.Tensor],
) -> None:
    """Attach ``payload`` to ``forward_context`` for later retrieval.

    The dict is stored by reference (not copied) and replaces any payload
    previously attached to the same context object.
    """
    attr_name = _PROMPT_PAYLOAD_ATTR
    setattr(forward_context, attr_name, payload)


def _kv_compression_layer_key(layer: Any) -> str:
    """Derive the string key that identifies ``layer`` in per-layer storage.

    Uses the layer's ``layer_name`` attribute when present and not ``None``.
    Otherwise it falls back to ``id(layer)``. The result is always coerced
    to ``str``.

    NOTE: ``id()`` values are only unique among objects alive at the same
    time, so the fallback key assumes the layer outlives the stored entry.
    """
    name = getattr(layer, "layer_name", None)
    source = id(layer) if name is None else name
    return str(source)


def get_kv_compression_compact_slots(
    forward_context: Any,
    *,
    per_layer_topk: bool,
    layer: Any,
) -> Optional[torch.Tensor]:
    """Look up the compact-slot tensor stored on ``forward_context``.

    Args:
        forward_context: Object the slots were attached to.
        per_layer_topk: If true, slots live in a per-layer dict keyed by
            ``_kv_compression_layer_key(layer)``. Otherwise a single tensor is
            read and ``layer`` is ignored.
        layer: Layer whose slots are requested. It is only consulted when
            ``per_layer_topk`` is true.

    Returns:
        The stored tensor, or ``None`` if nothing has been stored for this
        context (or, in per-layer mode, for this particular layer).
    """
    if not per_layer_topk:
        return getattr(forward_context, _COMPACT_SLOTS_ATTR, None)

    slots_by_layer = getattr(forward_context,
                             _COMPACT_SLOTS_BY_LAYER_ATTR, None)
    if slots_by_layer is None:
        # No per-layer dict has been created on this context yet.
        return None
    return slots_by_layer.get(_kv_compression_layer_key(layer))


def set_kv_compression_compact_slots(
    forward_context: Any,
    *,
    per_layer_topk: bool,
    layer: Any,
    dst: torch.Tensor,
) -> None:
    """Store a compact-slot tensor on ``forward_context``.

    Args:
        forward_context: Object to attach the slots to.
        per_layer_topk: If true, ``dst`` is recorded in a per-layer dict under
            ``_kv_compression_layer_key(layer)``. The dict is created on the
            first write. Otherwise ``dst`` overwrites the single shared slot
            and ``layer`` is ignored.
        layer: Layer the slots belong to. It is only consulted when
            ``per_layer_topk`` is true.
        dst: Tensor to store. It is stored by reference, not copied.
    """
    if not per_layer_topk:
        setattr(forward_context, _COMPACT_SLOTS_ATTR, dst)
        return

    slots_by_layer = getattr(forward_context,
                             _COMPACT_SLOTS_BY_LAYER_ATTR, None)
    if slots_by_layer is None:
        # First per-layer write on this context: attach an empty dict that
        # later writes (and reads) will share.
        slots_by_layer = {}
        setattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR, slots_by_layer)
    slots_by_layer[_kv_compression_layer_key(layer)] = dst