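# test_get_parameter_by_name.py
#
# Compares weights fetched from a running sglang backend via get_weights_by_name
# against the same parameters loaded directly with Hugging Face transformers,
# across data-parallel (dp) and tensor-parallel (tp) configurations.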
import gc
import unittest

import numpy as np
import requests
import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)
from sglang.utils import terminate_process


class TestGetParameterByName(unittest.TestCase):
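    # Load the reference model once through Hugging Face; every backend's
    # weights are checked against this copy.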
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.hf_model = AutoModelForCausalLM.from_pretrained(
            cls.model, torch_dtype="bfloat16"
        ).to("cuda:0")

    @classmethod
    def tearDownClass(cls):
        del cls.hf_model
        gc.collect()
        torch.cuda.empty_cache()

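    # Launch either an in-process Engine or a standalone server ("Runtime")
    # with the requested dp/tp sizes. Exactly one of self.engine/self.process
    # is set, so close_engine_and_server tears down whichever was started.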
    def init_backend(self, backend, dp, tp):
        self.engine = None
        self.process = None
        self.backend = backend
        self.dp = dp
        self.tp = tp
        if backend == "Engine":
            self.engine = sgl.Engine(
                model_path=self.model,
                random_seed=42,
                tp_size=self.tp,
                dp_size=self.dp,
                mem_fraction_static=0.85,
            )
        else:
            self.process = popen_launch_server(
                self.model,
                self.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=(
                    "--tp-size",
                    str(self.tp),
                    "--dp-size",
                    str(self.dp),
                ),
            )

    def close_engine_and_server(self):
        if self.engine:
            self.engine.shutdown()
        if self.process:
            terminate_process(self.process)

    def assert_weights_all_close(self, param_name, truncate_size):
        print(
            f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
        )
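        # Reference values: the first `truncate_size` rows of the HF parameter,
        # converted to float32 on CPU for a dtype-independent comparison.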
        param = self.hf_model.get_parameter(param_name)[:truncate_size]
        param_np = param.cpu().detach().float().numpy()

        if self.backend == "Engine":
            engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
            engine_ret = self._process_return(engine_ret)
            np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)

        if self.backend == "Runtime":
            runtime_ret = requests.get(
                f"{self.base_url}/get_weights_by_name",
                json={"name": param_name, "truncate_size": truncate_size},
            ).json()
            runtime_ret = self._process_return(runtime_ret)
            np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)

    @staticmethod
    def _process_return(ret):
        if isinstance(ret, list) and len(ret) == 2:
            print("running assert_allclose on data parallel")
            np.testing.assert_allclose(ret[0], ret[1])
            return np.array(ret[0])
        return np.array(ret)

    def test_get_parameters_by_name(self):
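        # Each tuple is (backend, dp_size, tp_size); multi-GPU configurations
        # are added only when enough devices are available.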
        test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)]

        if torch.cuda.device_count() >= 2:
            test_suits.append(("Engine", 1, 2))
            test_suits.append(("Runtime", 2, 1))

        if torch.cuda.device_count() >= 4:
            test_suits.extend([("Engine", 2, 2), ("Runtime", 2, 2)])

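        # The parameter names follow a Llama-style decoder layout, covering
        # embeddings, norms, attention/MLP projections, and the LM head.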
        parameters = [
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.2.self_attn.k_proj.weight",
            "model.layers.3.self_attn.v_proj.weight",
            "model.layers.4.self_attn.o_proj.weight",
            "model.layers.5.mlp.gate_proj.weight",
            "model.layers.6.mlp.up_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.layers.8.post_attention_layernorm.weight",
            "model.norm.weight",
            "lm_head.weight",
        ]

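        # For each configuration: launch the backend, verify the first 100
        # elements of every parameter, then shut the backend down.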
        for test_suit in test_suits:
            self.init_backend(*test_suit)
            for param_name in parameters:
                self.assert_weights_all_close(param_name, 100)
            self.close_engine_and_server()


if __name__ == "__main__":
    unittest.main()