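"""Test get_weights_by_name: fetch named weights from a running SGLang backend
(in-process Engine or a launched Runtime server) and compare them against a
HuggingFace reference model across several dp/tp layouts."""
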
import gc
import unittest

import numpy as np
import requests
import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    is_in_ci,
    popen_launch_server,
)
from sglang.utils import terminate_process


def _process_return(ret):
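    """Normalize a backend return value to a NumPy array.

    With dp_size == 2 the backend returns one copy of the weight per
    data-parallel rank as a two-element list; assert the replicas match
    and keep the first one.
    """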
    if isinstance(ret, list) and len(ret) == 2:
        print(f"running assert_allclose on data parallel")
        np.testing.assert_allclose(ret[0], ret[1])
        return np.array(ret[0])
    return np.array(ret)


class TestGetWeightsByName(unittest.TestCase):

    def init_hf_model(self, model_name, tie_word_embeddings):
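        """Load the HuggingFace reference model in bfloat16 on cuda:0."""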
        self.hf_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings
        ).to("cuda:0")

    def init_backend(self, backend, dp, tp, model_name):
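        """Start the backend under test: an in-process sgl.Engine, or a
        standalone server launched via popen_launch_server ("Runtime")."""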
        self.backend = backend
        self.dp = dp
        self.tp = tp
        if backend == "Engine":
            self.engine = sgl.Engine(
                model_path=model_name,
                random_seed=42,
                tp_size=tp,
                dp_size=dp,
            )
        else:
            self.process = popen_launch_server(
                model_name,
                DEFAULT_URL_FOR_TEST,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=(
                    "--tp-size",
                    str(tp),
                    "--dp-size",
                    str(dp),
                ),
            )

    def clean_up(self):
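        """Free the HF reference model and GPU memory, then shut down the backend."""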
        del self.hf_model
        gc.collect()
        torch.cuda.empty_cache()
        if self.backend == "Engine":
            self.engine.shutdown()
        else:
            terminate_process(self.process)

    def assert_tie_word_embeddings(self, truncate_size):
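        """For tied-embedding models, check that the backend's lm_head.weight
        matches the HF model's embed_tokens, and that HF ties them itself."""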
        print(f"assert_tie_word_embeddings")
        if self.backend == "Engine":
            backend_ret = _process_return(
                self.engine.get_weights_by_name("lm_head.weight", truncate_size)
            )
        else:
            backend_ret = _process_return(
                requests.get(
                    f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
                    json={"name": "lm_head.weight", "truncate_size": truncate_size},
                ).json()
            )
        print(f"assert_tie_word_embeddings of hf and backend")
        assert np.allclose(
            self.hf_model.get_parameter("model.embed_tokens.weight")
            .cpu()
            .detach()
            .float()
            .numpy()[:truncate_size],
            backend_ret,
        )
        assert np.allclose(
            self.hf_model.get_parameter("lm_head.weight")
            .cpu()
            .detach()
            .float()
            .numpy()[:truncate_size],
            self.hf_model.get_parameter("model.embed_tokens.weight")
            .cpu()
            .detach()
            .float()
            .numpy()[:truncate_size],
        )

    def assert_weights_all_close(self, param_name, truncate_size):
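        """Fetch the first truncate_size entries (along dim 0) of a named weight
        from the backend and compare them against the HF reference parameter."""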
        print(
            f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
        )
        param = self.hf_model.get_parameter(param_name)[:truncate_size]
        param_np = param.cpu().detach().float().numpy()

        if self.backend == "Engine":
            engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
            engine_ret = _process_return(engine_ret)
            np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)

        if self.backend == "Runtime":
            runtime_ret = requests.get(
                f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
                json={"name": param_name, "truncate_size": truncate_size},
            ).json()
            runtime_ret = _process_return(runtime_ret)
            np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)

    def test_get_weights_by_name(self):
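        """Run the comparison over (backend, dp, tp, model) combinations,
        scaling the matrix down in CI and up with the available GPU count."""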
        if is_in_ci():
            test_suits = [
                ("Engine", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
            ]
        else:
            test_suits = [
                ("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
                ("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST),
            ]
            if torch.cuda.device_count() >= 2:
                test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST))
                test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST))

            if torch.cuda.device_count() >= 4:
                test_suits.extend(
                    [
                        ("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
                        ("Runtime", 2, 2, DEFAULT_MODEL_NAME_FOR_TEST),
                    ]
                )

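        # One representative weight from each module type: embeddings, layer
        # norms, attention projections, MLP projections, final norm, lm_head.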
        parameters = [
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.2.self_attn.k_proj.weight",
            "model.layers.3.self_attn.v_proj.weight",
            "model.layers.4.self_attn.o_proj.weight",
            "model.layers.5.mlp.gate_proj.weight",
            "model.layers.6.mlp.up_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.layers.8.post_attention_layernorm.weight",
            "model.norm.weight",
            "lm_head.weight",
        ]

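        # Only the first 100 entries along dim 0 of each weight are compared.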
        truncate_size = 100

        for test_suit in test_suits:
            backend, dp, tp, model_name = test_suit
            # The reference model is loaded with tied word embeddings except for
            # DEFAULT_MODEL_NAME_FOR_TEST, which is tested untied.
            tie_word_embeddings = model_name != DEFAULT_MODEL_NAME_FOR_TEST

            self.init_hf_model(model_name, tie_word_embeddings)
            self.init_backend(backend, dp, tp, model_name)

            for param_name in parameters:
                self.assert_weights_all_close(param_name, truncate_size)

            if tie_word_embeddings:
                self.assert_tie_word_embeddings(truncate_size)

            self.clean_up()


if __name__ == "__main__":
    unittest.main()