# test_gemma.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest

# Hugging Face model ids the dummy-loader test is parametrized over.
MODELS = [
    "google/gemma-2b",
    "google/gemma-2-2b",
    "google/gemma-3-4b-it",
]


@pytest.mark.parametrize("model", MODELS)
def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None:
    """Check Gemma's embedding ``normalizer`` when loaded with dummy weights.

    The model is started with ``load_format="dummy"``. The ``normalizer``
    value is then read back from every worker via ``collective_rpc``. The
    test asserts that each value is close to ``hidden_size ** 0.5`` taken
    from the Hugging Face config.

    Args:
        vllm_runner: Fixture (defined outside this file, presumably in a
            conftest) that builds a vLLM runner usable as a context manager.
        monkeypatch: pytest's ``monkeypatch`` fixture, used to set an
            environment variable only for the duration of the test.
        model: Hugging Face model id, one of ``MODELS``.
    """
    with monkeypatch.context() as m:
        # collective_rpc below ships lambdas to the workers; presumably this
        # flag is what allows arbitrary callables to be serialized.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        with vllm_runner(
                model,
                load_format="dummy",
        ) as llm:
            if model == "google/gemma-3-4b-it":
                # This checkpoint nests the text model under
                # `language_model` and keeps its text settings in
                # `hf_config.text_config`.
                normalizers = llm.llm.collective_rpc(
                    lambda self: self.model_runner.model.language_model.model.
                    normalizer.cpu().item())
                config = llm.llm.llm_engine.model_config.hf_config.text_config
            else:
                normalizers = llm.llm.collective_rpc(
                    lambda self: self.model_runner.model.model.normalizer.cpu(
                    ).item())
                config = llm.llm.llm_engine.model_config.hf_config
            # `normalizers` presumably holds one value per worker. The loose
            # rtol likely absorbs reduced-precision storage of the normalizer
            # (NOTE(review): confirm against the model dtype).
            assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3)