test_vllm.py 1.79 KB
Newer Older
baberabb's avatar
baberabb committed
1
from typing import List
2
3

import pytest
Nathan Habib's avatar
Nathan Habib committed
4
import torch
baberabb's avatar
baberabb committed
5

6
from lm_eval import tasks
7
8
from lm_eval.api.instance import Instance

baberabb's avatar
baberabb committed
9

10
11
12
# Module-level TaskManager; the test class in this file calls its
# `load_task_or_group` to obtain the tasks it builds requests from.
task_manager = tasks.TaskManager()


baberabb's avatar
baberabb committed
13
@pytest.mark.skip(reason="requires CUDA")
class TEST_VLLM:
    """Smoke tests for the three request types of the vLLM model wrapper.

    Everything the tests need is built in the class body, i.e. once, when this
    module is imported: the model is instantiated and up to ten requests are
    prepared from each of ``arc_easy`` (loglikelihood), ``gsm8k``
    (generate_until) and ``wikitext`` (rolling loglikelihood).  The tests only
    check the number and the types of the returned results.

    NOTE(review): the ``skip`` mark prevents the test methods from running,
    but the class body (model construction included) still executes at import
    time -- confirm this is acceptable on machines without CUDA.
    NOTE(review): pytest's default ``python_classes`` prefix is ``Test``;
    confirm the project's configuration collects a class named ``TEST_VLLM``.
    """

    # Skips the whole module at import time when vllm is not installed.
    vllm = pytest.importorskip("vllm")

    try:
        from lm_eval.models.vllm_causallms import VLLM

        LM = VLLM(pretrained="EleutherAI/pythia-70m")
    except ModuleNotFoundError as err:
        # Previously swallowed with `pass`, which left `LM` undefined and made
        # every test fail later with an unrelated AttributeError.  Skip the
        # module with the real cause instead.
        pytest.skip(
            f"vLLM backend could not be loaded: {err}",
            allow_module_level=True,
        )

    # Global torch setting; applied when the module is imported.
    torch.use_deterministic_algorithms(True)

    task_list = task_manager.load_task_or_group(["arc_easy", "gsm8k", "wikitext"])

    # Requests for the loglikelihood test.
    multiple_choice_task = task_list["arc_easy"]  # type: ignore
    multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
    MULTIPLE_CH: List[Instance] = multiple_choice_task.instances

    # Requests for the generation test; cap the generation length at 10 tokens.
    generate_until_task = task_list["gsm8k"]  # type: ignore
    generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
    generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
    generate_until: List[Instance] = generate_until_task.instances

    # Requests for the rolling loglikelihood test.
    rolling_task = task_list["wikitext"]  # type: ignore
    rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
    ROLLING: List[Instance] = rolling_task.instances

    # TODO: make proper tests
    def test_logliklihood(self) -> None:
        """Check one result per request, each starting with a float."""
        res = self.LM.loglikelihood(self.MULTIPLE_CH)
        assert len(res) == len(self.MULTIPLE_CH)
        for x in res:
            assert isinstance(x[0], float)

    def test_generate_until(self) -> None:
        """Check one generated string per request."""
        res = self.LM.generate_until(self.generate_until)
        assert len(res) == len(self.generate_until)
        for x in res:
            assert isinstance(x, str)

    def test_logliklihood_rolling(self) -> None:
        """Check one float per rolling request."""
        res = self.LM.loglikelihood_rolling(self.ROLLING)
        # Same length check as the sibling tests; without it an empty or
        # truncated result would pass the loop below vacuously.
        assert len(res) == len(self.ROLLING)
        for x in res:
            assert isinstance(x, float)