utils.py 1.62 KB
Newer Older
1
2
"""Block manager utils."""
from vllm.sequence import SequenceGroup
3
4
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
                        STR_NOT_IMPL_ENC_DEC_SWA)
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


def _get_block_mgr_sliding_window_attr(block_mgr):
    """Return the sliding-window-attention member of a block manager.

    BlockManagerV1 and BlockManagerV2 have slightly different members
    related to sliding window attention (SWA). This function extracts
    the appropriate member to use for determining whether SWA is
    enabled.

    Arguments:

    * block_mgr: BlockManagerV1 or BlockManagerV2 instance

    Returns:

    * The value of ``block_mgr.block_sliding_window`` if that attribute
      exists, otherwise the value of ``block_mgr.max_block_sliding_window``.
      The value is returned as-is and may be ``None``.

    Raises:

    * AttributeError: if ``block_mgr`` has neither attribute.
    """
    # Order matters: when both members exist, block_sliding_window wins.
    for attr_name in ("block_sliding_window", "max_block_sliding_window"):
        if hasattr(block_mgr, attr_name):
            return getattr(block_mgr, attr_name)

    # Adjacent string literals are joined at compile time, so no "+" or
    # backslash line continuations are needed inside the parentheses.
    raise AttributeError("Block manager instance has neither "
                         "block_sliding_window nor "
                         "max_block_sliding_window attributes.")


def check_no_caching_or_swa_for_blockmgr_encdec(
        block_mgr, seq_group: SequenceGroup) -> None:
    """Reject block-manager features that encoder/decoder models lack.

    Prefix caching and sliding-window attention (SWA) are currently
    unsupported *specifically* for encoder/decoder models. Sequence
    groups that are not encoder/decoder are not checked at all.

    Arguments:

    * block_mgr: BlockSpaceManager instance
    * seq_group: SequenceGroup passed to block_mgr

    Raises:

    * NotImplementedError: if ``seq_group`` is encoder/decoder and either
      the block manager's sliding-window member is not ``None`` or
      ``block_mgr.enable_caching`` is truthy. The SWA check runs first,
      so it takes precedence when both conditions hold.
    """
    # Guard clause: only encoder/decoder sequence groups are restricted.
    if not seq_group.is_encoder_decoder():
        return

    swa_setting = _get_block_mgr_sliding_window_attr(block_mgr)
    if swa_setting is not None:
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA)

    if block_mgr.enable_caching:
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE)