"vscode:/vscode.git/clone" did not exist on "7940d8a6a7841113c89b168080b51785946f0cdb"
test_sharded_state_loader.py 2.75 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import fnmatch
import os
import shutil
from tempfile import TemporaryDirectory

import pytest
import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.model_executor.model_loader.loader import ShardedStateLoader

# Prompts sent to every engine instance in this module; the generated
# outputs are compared between engines.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling configuration shared by every generate() call below. A fixed
# seed is set so that outputs of independently constructed engines can be
# compared exactly; ignore_eos is set, presumably so every completion runs
# to max_tokens rather than stopping early.
sampling_params = SamplingParams(
    max_tokens=256,
    ignore_eos=True,
    seed=0,
    temperature=0.8,
    top_p=0.95,
)


def test_filter_subtensors():
    """Check that ``ShardedStateLoader._filter_subtensors`` drops aliases.

    ``"x"`` is the same tensor object as ``"b"``, and ``"y"``/``"z"`` are
    views into ``"c"``. Only the three independently allocated tensors are
    expected to remain, in insertion order, with unchanged contents.
    """
    # Use initialized data rather than torch.empty(): uninitialized memory
    # may contain NaN bit patterns, and since NaN != NaN the equality check
    # below could then fail spuriously.
    state_dict = {
        "a": torch.arange(2, dtype=torch.float32),
        "b": torch.arange(2 * 4, dtype=torch.float32).reshape(2, 4),
        "c": torch.arange(2 * 4 * 8, dtype=torch.float32).reshape(2, 4, 8),
    }
    state_dict.update({
        "x": state_dict["b"],  # same tensor object as "b"
        "y": state_dict["c"][1, 2, :],  # view into "c"
        "z": state_dict["c"][1, :, 4],  # strided view into "c"
    })
    filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
    assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
    for key, tensor in filtered_state_dict.items():
        assert tensor.equal(state_dict[key])


def _copy_metadata_files(src_dir, dst_dir, weights_patterns):
    """Copy every regular file in ``src_dir`` that is not a weight file.

    Args:
        src_dir: Directory to copy from (not recursed into).
        dst_dir: Directory to copy into.
        weights_patterns: Glob patterns (e.g. ``"*.safetensors"``) naming
            the files to skip.
    """
    for name in os.listdir(src_dir):
        src_path = os.path.join(src_dir, name)
        # shutil.copy raises on directories; only copy files (os.path.isfile
        # follows symlinks, so symlinked files are still copied).
        if not os.path.isfile(src_path):
            continue
        # The patterns are globs, so match them with fnmatch; a plain
        # str.endswith("*.bin") never matches a real filename.
        if any(fnmatch.fnmatch(name, pat) for pat in weights_patterns):
            continue
        shutil.copy(src_path, dst_dir)


@pytest.mark.parametrize("enable_lora", [False, True])
def test_sharded_state_loader(enable_lora):
    """Generation must be identical after a sharded-state save/load cycle.

    The original checkpoint is loaded, its worker states are dumped with
    ``save_sharded_state``, and the non-weight files are copied next to
    them. Outputs from the original directory and from the re-saved
    directory (``load_format="sharded_state"``) must then match exactly.
    """
    weights_patterns = ("*.bin", "*.pt", "*.safetensors")

    with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir:
        input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
                                      cache_dir=cache_dir)

        llm = LLM(
            model=input_dir,
            worker_use_ray=True,
            gpu_memory_utilization=0.3,
        )

        # Dump worker states to the output directory.
        llm.llm_engine.model_executor.save_sharded_state(path=output_dir)
        # Copy the non-weight (metadata) files so output_dir is a complete
        # model directory without duplicating the original weights.
        _copy_metadata_files(input_dir, output_dir, weights_patterns)
        # Presumably frees this engine's GPU resources before the next one
        # is created.
        del llm.llm_engine.model_executor

        llm_before = LLM(
            model=input_dir,
            worker_use_ray=True,
            enable_lora=enable_lora,
            gpu_memory_utilization=0.3,
        )
        gen_before = llm_before.generate(prompts, sampling_params)
        out_before = [gen.outputs[0].__dict__ for gen in gen_before]
        del llm_before.llm_engine.model_executor

        llm_after = LLM(
            model=output_dir,
            worker_use_ray=True,
            enable_lora=enable_lora,
            gpu_memory_utilization=0.3,
            load_format="sharded_state",
        )
        gen_after = llm_after.generate(prompts, sampling_params)
        out_after = [gen.outputs[0].__dict__ for gen in gen_after]
        del llm_after.llm_engine.model_executor

        assert out_before == out_after