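"""Test that the SGLang engine returns per-token hidden states.

Compares the last-layer hidden states returned by an offline sgl.Engine
against a HuggingFace forward pass over the same token sequence.
"""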
import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase


class TestHiddenState(CustomTestCase):
    def test_return_hidden_states(self):
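        """Hidden states from the engine should match a HuggingFace forward pass."""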
        prompts = ["Today is", "Today is a sunny day and I like"]
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        input_ids = tokenizer(prompts).input_ids

        sampling_params = {
            "temperature": 0,
            "max_new_tokens": 8,
        }

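        # Launch an offline engine. enable_return_hidden_states turns on
        # hidden-state capture; skip_tokenizer_init lets us pass token ids
        # directly instead of raw prompts.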
        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            skip_tokenizer_init=True,
            enable_return_hidden_states=True,
        )
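        # Hidden states are requested per call and returned per request under
        # output["meta_info"]["hidden_states"].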
        outputs = engine.generate(
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_hidden_states=True,
        )
        engine.shutdown()

        for output in outputs:
            # One hidden-state entry per generated token (max_new_tokens == 8):
            # a prefill entry followed by one entry per decode step.
            self.assertEqual(len(output["meta_info"]["hidden_states"]), 8)
            # Each entry arrives as a plain Python list; convert to bf16
            # tensors for comparison against the HuggingFace reference below.
            for i in range(len(output["meta_info"]["hidden_states"])):
                self.assertIsInstance(output["meta_info"]["hidden_states"][i], list)
                output["meta_info"]["hidden_states"][i] = torch.tensor(
                    output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
                )
        # Check that the batched request was split correctly: the longer
        # prompt must yield a larger prefill hidden-state tensor.
        self.assertGreater(
            outputs[1]["meta_info"]["hidden_states"][0].shape[0],
            outputs[0]["meta_info"]["hidden_states"][0].shape[0],
        )

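        # Recompute the same sequences with HuggingFace to obtain reference
        # last-layer hidden states.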
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda"
        )

        for input_id, output in zip(input_ids, outputs):
            # Feed the prompt plus all but the last generated token, so the
            # HF forward pass covers exactly the positions SGLang returned
            # hidden states for (the prompt, then one per decode step).
            with torch.inference_mode():
                hf_out = model(
                    torch.tensor(
                        [input_id + output["output_ids"][:-1]], device=model.device
                    ),
                    output_hidden_states=True,
                )
            print("=== HF Hiddens ===")
            print(hf_out["hidden_states"][-1][0])
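            # Entry 0 (prefill) is a [prompt_len, hidden_size] tensor; each
            # decode entry is a single [hidden_size] vector, so 1-D entries
            # are unsqueezed before concatenation.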
            sg_hidden_states = torch.cat(
                [
                    i.unsqueeze(0) if len(i.shape) == 1 else i
                    for i in output["meta_info"]["hidden_states"]
                ]
            ).to("cuda")
            print("=== SRT Hiddens ===")
            print(sg_hidden_states)

            print(
                f"Max diff: {torch.max(torch.abs(hf_out['hidden_states'][-1][0] - sg_hidden_states))}"
            )

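            # Loose absolute tolerance: exact bf16 agreement between SGLang
            # and HuggingFace kernels is not expected.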
            atol = 0.8
            self.assertTrue(
                torch.allclose(
                    hf_out["hidden_states"][-1][0],
                    sg_hidden_states,
                    atol=atol,
                    rtol=0,
                )
            )

    def test_repeatedly_changes_hidden_states(self):
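        """Toggling return_hidden_states across requests must work on one engine."""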
        prompts = ["Today is", "Today is a sunny day and I like"]
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        input_ids = tokenizer(prompts).input_ids

        sampling_params = {
            "temperature": 0,
            "max_new_tokens": 8,
        }

        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            skip_tokenizer_init=True,
            enable_return_hidden_states=True,
        )
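        # Issue three rounds on the same engine, toggling return_hidden_states:
        # on for the first round, off for the second, on again for the last.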
        outputs_completion_first_round = engine.generate(
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_hidden_states=True,
        )
        outputs_no_hidden_states = engine.generate(
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_hidden_states=False,
        )

        outputs_completion_last_round = engine.generate(
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_hidden_states=True,
        )
        engine.shutdown()

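        # Hidden states must be present exactly for the rounds that asked
        # for them and absent for the round that did not.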
        for (
            output_completion_first_round,
            output_no_hidden_states,
            output_completion_last_round,
        ) in zip(
            outputs_completion_first_round,
            outputs_no_hidden_states,
            outputs_completion_last_round,
        ):
            self.assertEqual(
                len(output_completion_first_round["meta_info"]["hidden_states"]), 8
            )
            self.assertNotIn("hidden_states", output_no_hidden_states["meta_info"])
            self.assertEqual(
                len(output_completion_last_round["meta_info"]["hidden_states"]), 8
            )


if __name__ == "__main__":
    unittest.main()