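"""Tests for the get_weights_by_name API.

Fetch named parameters from a running sglang backend (an in-process Engine
or an HTTP server) and compare them against the same weights loaded directly
through Hugging Face transformers, across tensor- and data-parallel setups.
"""
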
import gc
import unittest

import numpy as np
import requests
import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)
from sglang.utils import terminate_process


def _process_return(ret):
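    """Normalize a get_weights_by_name result to a numpy array.

    Under data parallelism the backend returns one copy of the tensor per
    replica (a list of length 2 for dp_size=2); the copies must agree.
    """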
    if isinstance(ret, list) and len(ret) == 2:
        print(f"running assert_allclose on data parallel")
        np.testing.assert_allclose(ret[0], ret[1])
        return np.array(ret[0])
    return np.array(ret)


class TestGetWeightsByName(CustomTestCase):
    def init_hf_model(self, model_name, tie_word_embeddings):
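        """Load the Hugging Face reference model in bfloat16 on cuda:0."""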
        self.hf_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings
        ).to("cuda:0")

    def init_backend(self, backend, dp, tp, model_name):
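        """Start the backend under test.

        "Engine" runs in-process via sgl.Engine; otherwise a server
        subprocess is launched and queried over HTTP.
        """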
        self.backend = backend
        self.dp = dp
        self.tp = tp
        if backend == "Engine":
            self.engine = sgl.Engine(
                model_path=model_name,
                random_seed=42,
                tp_size=tp,
                dp_size=dp,
            )
        else:
            self.process = popen_launch_server(
                model_name,
                DEFAULT_URL_FOR_TEST,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=(
                    "--tp-size",
                    str(tp),
                    "--dp-size",
                    str(dp),
                ),
            )

    def clean_up(self):
        del self.hf_model
        gc.collect()
        torch.cuda.empty_cache()
        if self.backend == "Engine":
            self.engine.shutdown()
        else:
            terminate_process(self.process)

    def assert_tie_word_embeddings(self, truncate_size):
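        """With tied word embeddings, lm_head.weight fetched from the
        backend must match the HF embedding table."""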
        print("assert_tie_word_embeddings")
        if self.backend == "Engine":
            backend_ret = _process_return(
                self.engine.get_weights_by_name("lm_head.weight", truncate_size)
            )
        else:
            backend_ret = _process_return(
                requests.get(
                    f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
                    json={"name": "lm_head.weight", "truncate_size": truncate_size},
                ).json()
            )
        print("assert_tie_word_embeddings of hf and backend")
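        # With tied embeddings, the backend should serve the embedding table for
        # "lm_head.weight", and the two HF parameters should themselves match.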
        assert np.allclose(
            self.hf_model.get_parameter("model.embed_tokens.weight")
            .cpu()
            .detach()
            .float()
            .numpy()[:truncate_size],
            backend_ret,
        )
        assert np.allclose(
            self.hf_model.get_parameter("lm_head.weight")
            .cpu()
            .detach()
            .float()
            .numpy()[:truncate_size],
            self.hf_model.get_parameter("model.embed_tokens.weight")
            .cpu()
            .detach()
            .float()
            .numpy()[:truncate_size],
        )

    def assert_weights_all_close(self, param_name, truncate_size):
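        """Fetch one named parameter from the backend and compare its first
        truncate_size rows against the HF reference."""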
        print(
            f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
        )
        param = self.hf_model.get_parameter(param_name)[:truncate_size]
        param_np = param.cpu().detach().float().numpy()

        if self.backend == "Engine":
            engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
            engine_ret = _process_return(engine_ret)
            np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)

        if self.backend == "Runtime":
            runtime_ret = requests.get(
                f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
                json={"name": param_name, "truncate_size": truncate_size},
            ).json()
            runtime_ret = _process_return(runtime_ret)
            np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)

    def test_get_weights_by_name(self):
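        """Run the comparison over a matrix of (backend, dp_size, tp_size,
        model) combinations, scaled down in CI and by available GPU count."""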
        if is_in_ci():
            test_suits = [
                ("Engine", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
            ]
        else:
            test_suits = [
                ("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
                ("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST),
            ]
            if torch.cuda.device_count() >= 2:
                test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST))
                test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST))

            if torch.cuda.device_count() >= 4:
                test_suits.extend(
                    [
                        ("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
                        ("Runtime", 2, 2, DEFAULT_MODEL_NAME_FOR_TEST),
                    ]
                )

        parameters = [
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.2.self_attn.k_proj.weight",
            "model.layers.3.self_attn.v_proj.weight",
            "model.layers.4.self_attn.o_proj.weight",
            "model.layers.5.mlp.gate_proj.weight",
            "model.layers.6.mlp.up_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.layers.8.post_attention_layernorm.weight",
            "model.norm.weight",
            "lm_head.weight",
        ]

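        # Compare only the leading truncate_size rows of each tensor.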
        truncate_size = 100

        for test_suit in test_suits:
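            # DEFAULT_MODEL_NAME_FOR_TEST does not tie its word embeddings;
            # the small test model does.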
            if test_suit[-1] == DEFAULT_MODEL_NAME_FOR_TEST:
                tie_word_embeddings = False
            else:
                tie_word_embeddings = True

            self.init_hf_model(test_suit[-1], tie_word_embeddings)
            self.init_backend(*test_suit)

            for param_name in parameters:
                self.assert_weights_all_close(param_name, truncate_size)

            if tie_word_embeddings:
                self.assert_tie_word_embeddings(truncate_size)

            self.clean_up()


if __name__ == "__main__":
    unittest.main()