# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MultiModalConfig


def test_mm_encoder_attn_backend_str_conversion():
    """A backend given by its string name is converted to the enum member.

    Identity (``is``) is used rather than ``==``: enum members are
    singletons, so this proves the stored value is the actual
    ``AttentionBackendEnum`` member and not merely something that compares
    equal to it.
    """
    config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
    assert config.mm_encoder_attn_backend is AttentionBackendEnum.FLASH_ATTN


def test_mm_encoder_attn_backend_invalid():
    """An unrecognised backend name is rejected with ``ValueError``."""
    bogus_backend = "not_a_backend"
    with pytest.raises(ValueError):
        MultiModalConfig(mm_encoder_attn_backend=bogus_backend)


def test_mm_encoder_attn_backend_hash_updates():
    """Overriding the encoder attention backend changes the config hash.

    The default hash is first checked for reproducibility across two fresh
    instances, so the final inequality cannot pass merely because
    ``compute_hash`` returns a different value on every call.
    """
    base_hash = MultiModalConfig().compute_hash()
    # Guard against a vacuous pass: the default hash must be stable.
    assert MultiModalConfig().compute_hash() == base_hash

    overridden_hash = MultiModalConfig(
        mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
    ).compute_hash()
    assert base_hash != overridden_hash