# test_forward_split_prefill.py
"""
Test forward_split_prefill functionality.

Usage:
python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill
or
python3 test_forward_split_prefill.py
"""

import unittest

import numpy as np
import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
17
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
18
19
20
21
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
22
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase


class TestForwardSplitPrefill(CustomTestCase):
    """Test cases for ModelRunner.forward_split_prefill.

    Split prefill runs the prefill forward pass a few decoder layers at a time
    (``forward_count`` layers per call) and advances ``forward_batch.split_index``
    until it reaches ``num_hidden_layers``. These tests check that:

    - the index advances as expected;
    - only the final call returns logits;
    - the final logits match a regular (non-split) extend pass and are
      independent of the chunk size.
    """

    @classmethod
    def setUpClass(cls):
        """Load the test model, tokenizer and a single-GPU ModelRunner once."""
        cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.tp_size = 1
        cls.device = "cuda"

        # Initialize server args
        cls.server_args = ServerArgs(
            model_path=cls.model_path,
            tokenizer_path=cls.model_path,
            host="127.0.0.1",
            disable_cuda_graph=True,  # Disable CUDA graph for testing split prefill
            disable_hybrid_swa_memory=True,
            port=30000,
            tp_size=cls.tp_size,
            mem_fraction_static=0.8,
            trust_remote_code=True,
        )

        cls.port_args = PortArgs.init_new(cls.server_args)

        # Load model and tokenizer
        cls.model_config = ModelConfig.from_server_args(cls.server_args)
        cls.model_runner = ModelRunner(
            model_config=cls.model_config,
            mem_fraction_static=cls.server_args.mem_fraction_static,
            gpu_id=0,
            tp_rank=0,
            tp_size=cls.tp_size,
            pp_rank=0,
            pp_size=1,
            nccl_port=cls.port_args.nccl_port,
            server_args=cls.server_args,
        )

        cls.tokenizer = get_tokenizer(
            cls.server_args.tokenizer_path,
            tokenizer_mode=cls.server_args.tokenizer_mode,
            trust_remote_code=cls.server_args.trust_remote_code,
        )

        print(
            f"Test with model: {cls.model_path}, num_hidden_layers: {cls.model_config.num_hidden_layers}"
        )

    def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True):
        """Build a ForwardBatch of random prompts for split-prefill testing.

        Args:
            batch_size: Number of requests in the batch.
            input_len: Prompt length, in tokens, of every request.
            is_split_prefill: If True, prepare the batch with
                ``prepare_for_split_prefill``. Otherwise use
                ``prepare_for_extend`` (a regular prefill).

        Returns:
            A ForwardBatch built from the prepared ScheduleBatch.

        Token ids are drawn at random on every call, so two calls never share
        inputs. Tests that compare results copy ``input_ids``/``positions``
        from one batch onto another.
        """
        # Random ids in [10, 1000); assumed to be inside the test model's
        # vocabulary -- NOTE(review): confirm if the default model changes.
        input_ids = np.random.randint(10, 1000, (batch_size, input_len), dtype=np.int32)

        sampling_params = SamplingParams(
            temperature=0.0,
            max_new_tokens=8,
        )

        reqs = []
        for i in range(batch_size):
            req = Req(
                rid=i,
                origin_input_text="",
                # .tolist() gives plain Python ints instead of np.int32 scalars.
                origin_input_ids=input_ids[i].tolist(),
                sampling_params=sampling_params,
            )
            # No cached prefix: the whole prompt is extended in this forward.
            req.prefix_indices = []
            req.fill_ids = req.origin_input_ids
            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
            # Presumably limits prompt logprobs to the last prompt position.
            req.logprob_start_len = len(req.origin_input_ids) - 1
            reqs.append(req)

        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            tree_cache=None,
            model_config=self.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            enable_custom_logit_processor=False,
        )
        if is_split_prefill:
            batch.prepare_for_split_prefill()
        else:
            batch.prepare_for_extend()

        # Create forward batch
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

        return forward_batch

    def _run_split_prefill(self, forward_batch, chunk_size):
        """Run split prefill over every layer, ``chunk_size`` layers per call.

        The attention backend is re-initialised on the first call, matching
        how the other tests drive forward_split_prefill. The test fails
        immediately if a call does not advance ``split_index``, instead of
        looping forever.

        Returns:
            The last non-None result returned by forward_split_prefill, or
            None if no call returned anything.
        """
        num_layers = self.model_config.num_hidden_layers
        split_result = None
        is_first_split = True

        while forward_batch.split_index < num_layers:
            index_before = forward_batch.split_index
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=is_first_split,
                forward_count=chunk_size,
            )
            is_first_split = False
            self.assertGreater(
                forward_batch.split_index,
                index_before,
                "forward_split_prefill did not advance split_index",
            )
            if result is not None:
                split_result = result

        return split_result

    def test_split_prefill_functionality(self):
        """Test that split prefill can complete successfully."""
        print("\n=== Testing split prefill functionality ===")

        forward_batch = self.prepare_test_batch(batch_size=2, input_len=64)

        # Reset split index
        forward_batch.split_index = 0

        # Test split prefill in chunks
        num_layers = self.model_config.num_hidden_layers
        # Roughly 4 chunks; one more when num_layers is not divisible by 4.
        chunk_size = max(1, num_layers // 4)

        results = []
        split_count = 0

        while forward_batch.split_index < num_layers:
            print(
                f"Processing split {split_count}, split_index: {forward_batch.split_index}"
            )

            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(split_count == 0),
                forward_count=chunk_size,
            )

            results.append(result)
            split_count += 1

            # split_index must advance by chunk_size, clamped to num_layers.
            expected_next_index = min(split_count * chunk_size, num_layers)
            self.assertEqual(forward_batch.split_index, expected_next_index)

        # The last result should contain logits
        self.assertIsNotNone(results[-1], "Final split should return logits")
        print(f"Split prefill completed in {split_count} splits")

    def test_split_prefill_vs_normal_prefill(self):
        """Test that split prefill produces the same results as normal prefill."""
        print("\n=== Testing split prefill vs normal prefill consistency ===")

        forward_batch_normal = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=False
        )
        forward_batch_split = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=True
        )

        # prepare_test_batch draws fresh random ids, so copy the normal
        # batch's inputs onto the split batch.
        forward_batch_split.input_ids = forward_batch_normal.input_ids.clone()
        forward_batch_split.positions = forward_batch_normal.positions.clone()

        # Method 1: Normal extend (prefill)
        print("Running normal extend (prefill)...")
        normal_result = self.model_runner.forward_extend(forward_batch_normal)

        # Method 2: Split prefill. The attention backend was last used by
        # forward_extend on a different batch; the helper re-initialises it on
        # the first split.
        print("Running split prefill...")
        num_layers = self.model_config.num_hidden_layers
        # Roughly 3 chunks; one more when num_layers is not divisible by 3.
        chunk_size = max(1, num_layers // 3)

        split_result = self._run_split_prefill(forward_batch_split, chunk_size)

        # Compare results
        self.assertIsNotNone(normal_result, "Normal prefill should return result")
        self.assertIsNotNone(split_result, "Split prefill should return result")

        # Compare logits shapes
        self.assertEqual(
            normal_result.next_token_logits.shape,
            split_result.next_token_logits.shape,
            "Logits shapes should match",
        )

        # Compare logits values (should be very close due to same computation)
        # Use a larger tolerance for numerical differences in split computation
        torch.testing.assert_close(
            normal_result.next_token_logits,
            split_result.next_token_logits,
            rtol=1e-3,
            atol=1e-3,
            msg="Split prefill and normal prefill should produce similar logits",
        )

        print("✓ Split prefill and normal prefill produce consistent results")

    def test_split_prefill_different_chunk_sizes(self):
        """Test split prefill with different chunk sizes."""
        print("\n=== Testing split prefill with different chunk sizes ===")

        num_layers = self.model_config.num_hidden_layers
        # Dedupe (order preserved) and drop sizes larger than the model up
        # front. Each entry then lines up with its result below. For example,
        # a 2-layer model turns [1, 2, 1, 2] into [1, 2].
        chunk_sizes = [
            size
            for size in dict.fromkeys([1, 2, max(1, num_layers // 2), num_layers])
            if size <= num_layers
        ]

        # Fix one set of inputs and replay it for every chunk size.
        base_batch = self.prepare_test_batch(batch_size=1, input_len=16)
        base_input_ids = base_batch.input_ids.clone()
        base_positions = base_batch.positions.clone()

        results = []

        for chunk_size in chunk_sizes:
            print(f"Testing chunk size: {chunk_size}")

            # Prepare fresh batch
            forward_batch = self.prepare_test_batch(batch_size=1, input_len=16)
            forward_batch.input_ids = base_input_ids.clone()
            forward_batch.positions = base_positions.clone()
            forward_batch.split_index = 0

            split_result = self._run_split_prefill(forward_batch, chunk_size)

            self.assertIsNotNone(
                split_result,
                f"Split prefill should succeed with chunk_size={chunk_size}",
            )
            results.append(split_result)

        # Every chunk size should reproduce the first one's logits
        # (same input, same computation).
        for chunk_size, result in zip(chunk_sizes[1:], results[1:]):
            torch.testing.assert_close(
                results[0].next_token_logits,
                result.next_token_logits,
                rtol=1e-3,
                atol=1e-3,
                msg=f"Results with different chunk sizes should be identical (chunk_size {chunk_size})",
            )

        print("✓ All chunk sizes produce consistent results")

    def test_split_prefill_edge_cases(self):
        """Test edge cases for split prefill."""
        print("\n=== Testing split prefill edge cases ===")

        # Test with single layer chunks
        forward_batch = self.prepare_test_batch(batch_size=1, input_len=8)

        # Process one layer at a time
        num_layers = self.model_config.num_hidden_layers
        for layer_idx in range(num_layers):
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(layer_idx == 0),
                forward_count=1,  # One layer at a time
            )

            if layer_idx == num_layers - 1:
                # Last layer should return result
                self.assertIsNotNone(result, "Last layer should return logits")
            else:
                # Intermediate layers should return None
                self.assertIsNone(result, f"Layer {layer_idx} should return None")

        # After one call per layer, every layer must have been consumed.
        self.assertEqual(
            forward_batch.split_index,
            num_layers,
            "split_index should reach num_hidden_layers after the last layer",
        )

        print("✓ Single layer processing works correctly")


if __name__ == "__main__":
    # Hand control to the stdlib runner, discovering tests in this script.
    unittest.main(module="__main__")