"vscode:/vscode.git/clone" did not exist on "56d04089ef508003c684c90429046d90f2117547"
test_eagle.py 11.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# SPDX-License-Identifier: Apache-2.0

from unittest import mock

import pytest
import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.v1.spec_decode.eagle import EagleProposer

# Target model identifier; _create_proposer passes it to ModelConfig as both
# the `model` and the `tokenizer`.
# NOTE(review): these look like Hugging Face Hub repo ids - confirm.
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
# Draft model selected by _create_proposer when method == "eagle".
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
# Draft model selected by _create_proposer for any other method value
# (the tests in this file only use "eagle3").
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def _create_proposer(method: str, k: int) -> EagleProposer:
    """Build an ``EagleProposer`` on CUDA for the given speculative method.

    Args:
        method: Speculative decoding method. ``"eagle"`` selects
            ``eagle_dir`` as the draft model; every other value selects
            ``eagle3_dir``.
        k: Number of speculative tokens passed to ``SpeculativeConfig``.

    Returns:
        An ``EagleProposer`` constructed from a ``VllmConfig`` that uses
        default cache, parallel, load and scheduler configs.
    """
    target_config = ModelConfig(
        model=model_dir,
        task="generate",
        max_model_len=100,
        tokenizer=model_dir,
        tokenizer_mode="auto",
        dtype="auto",
        seed=None,
        trust_remote_code=False,
    )

    # Pick the draft checkpoint that matches the requested method.
    if method == "eagle":
        draft_model_dir = eagle_dir
    else:
        draft_model_dir = eagle3_dir

    spec_config = SpeculativeConfig(
        target_model_config=target_config,
        target_parallel_config=ParallelConfig(),
        model=draft_model_dir,
        method=method,
        num_speculative_tokens=k,
    )

    full_config = VllmConfig(
        model_config=target_config,
        cache_config=CacheConfig(),
        speculative_config=spec_config,
        device_config=DeviceConfig(device="cuda"),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig(),
    )
    return EagleProposer(vllm_config=full_config, device='cuda')


def test_prepare_inputs():
    """Check ``EagleProposer.prepare_inputs`` on a three-request batch.

    With cumulative target query lengths ``[0, a, a + b, a + b + c]`` and
    per-request rejected counts ``[n1, n2, n3]`` the test expects:

    * tokens kept per request: ``[a - n1, b - n2, c - n3]``
    * cumulative kept tokens:
      ``[0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]``
    * token indices: the first ``a - n1`` positions starting at ``0``, the
      first ``b - n2`` starting at ``a``, and the first ``c - n3`` starting
      at ``a + b``.

    Concrete values used here: a, b, c = 4, 7, 5 and n1, n2, n3 = 1, 3, 2.
    """
    device = torch.device('cuda')
    int32_on_device = dict(dtype=torch.int32, device=device)

    # Query lengths per request are [4, 7, 5] -> cumulative [0, 4, 11, 16].
    cu_target_query_lens = torch.tensor([0, 4, 11, 16], **int32_on_device)
    num_rejected_tokens = torch.tensor([1, 3, 2], **int32_on_device)

    # Kept tokens per request are [3, 4, 3] -> cumulative [0, 3, 7, 10].
    expected_cu_num_tokens = torch.tensor([0, 3, 7, 10], **int32_on_device)

    # Each request keeps a prefix of its original position range.
    kept_first_request = [0, 1, 2]  # 3 of positions 0-3
    kept_second_request = [4, 5, 6, 7]  # 4 of positions 4-10
    kept_third_request = [11, 12, 13]  # 3 of positions 11-15
    expected_token_indices = torch.tensor(
        kept_first_request + kept_second_request + kept_third_request,
        **int32_on_device)

    cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
        cu_target_query_lens, num_rejected_tokens)

    assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)


@pytest.mark.parametrize(
    "method,proposer_helper,draft_model_dir,target_attribute_path", [
        ("eagle", lambda k: _create_proposer("eagle", k), eagle_dir,
         ('lm_head', )),
        ("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
         ('model', 'embed_tokens')),
    ])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
                    proposer_helper, draft_model_dir, target_attribute_path):
    """Check that ``load_model`` shares the expected module with the target.

    ``get_model``, ``get_layers_from_vllm_config`` and ``get_pp_group`` are
    patched in ``vllm.v1.spec_decode.eagle`` so no real weights are loaded.

    Args:
        mock_get_model: Patched ``get_model``; returns a ``MagicMock`` model.
        mock_get_layers: Patched ``get_layers_from_vllm_config``; yields the
            target attention layers on the first call and the target plus
            one extra draft layer on the second.
        mock_get_pp_group: Patched ``get_pp_group``; its return value reports
            ``world_size`` 2 for "eagle" and 1 otherwise.
        method: Speculative method under test ("eagle" or "eagle3").
        proposer_helper: Callable taking ``k`` and returning the proposer.
        draft_model_dir: Draft model id for this case. It is part of the
            parametrization but is not read by the test body.
        target_attribute_path: Attribute names, outermost first, that are
            created explicitly on the mocked target model.
    """
    # Setup model mock
    mock_model = mock.MagicMock()
    mock_get_model.return_value = mock_model

    # Setup mocks for attention layers
    target_attn_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock()
    }
    # Draft model has one extra attention layer compared to target model
    all_attn_layers = {
        **target_attn_layers, "draft_extra_attn": mock.MagicMock()
    }

    # Make mock_get_layers return different values for each call
    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

    # Setup mock for pp group to return the appropriate value for world size
    # NOTE(review): presumably the world size controls which modules
    # load_model shares - confirm against EagleProposer.load_model.
    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 2 if method == "eagle" else 1
    mock_get_pp_group.return_value = mock_pp_group

    # Setup target model with the appropriate attributes
    target_model = mock.MagicMock()

    # Walk the attribute path, attaching a fresh MagicMock at every level so
    # each attribute on the path is an explicit, stable object.
    current_obj = target_model
    for attr in target_attribute_path:
        child = mock.MagicMock()
        setattr(current_obj, attr, child)
        current_obj = child

    # Create proposer using the helper function
    proposer = proposer_helper(k=8)

    # Call the method under test
    proposer.load_model(target_model)

    # Verify common interactions
    mock_get_model.assert_called_once()

    # Verify the specific attribute sharing based on the method
    if method == "eagle":
        assert proposer.model.lm_head == target_model.lm_head
    else:
        assert proposer.model.model.embed_tokens == \
            target_model.model.embed_tokens


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens):
    """Check ``EagleProposer.propose`` against a mocked draft model.

    The proposer's model is replaced by a ``MagicMock`` whose forward call
    returns all-zero tensors and whose ``compute_logits`` returns logits that
    strongly favour one token id per request. The favoured ids start at 42
    and 60 and increase by one on every step. The test asserts that the
    result has shape ``(batch_size, num_speculative_tokens)`` and equals
    those token sequences.
    """
    device = torch.device('cuda')

    # Two requests with 5 and 3 target tokens respectively.
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100

    # Build the proposer first so the mocks can reuse its real hidden size.
    proposer = _create_proposer("eagle", num_speculative_tokens)
    hidden_size = proposer.hidden_size

    # Request 0 proposes 42, 43, 44, ...; request 1 proposes 60, 61, 62, ...
    base_token_ids = [42, 60]

    def create_deterministic_logits(token_ids):
        """Return logits that are 100 at each row's token id, else -100."""
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        rows = torch.arange(len(token_ids), device=device)
        cols = torch.tensor(token_ids, device=device)
        logits[rows, cols] = 100.0
        return logits

    def zero_forward_output(num_rows):
        """Return a pair of zero tensors shaped (num_rows, hidden_size)."""
        return (torch.zeros(num_rows, hidden_size, device=device),
                torch.zeros(num_rows, hidden_size, device=device))

    # The first forward call covers every token; later calls cover one token
    # per request.
    forward_returns = [
        zero_forward_output(total_tokens if step == 0 else batch_size)
        for step in range(num_speculative_tokens)
    ]

    # Step `step` favours base id + step for each request.
    logits_returns = [
        create_deterministic_logits(
            [base_id + step for base_id in base_token_ids])
        for step in range(num_speculative_tokens)
    ]

    # Skip model loading entirely: the mock stands in for the draft model.
    model_mock = mock.MagicMock()
    if num_speculative_tokens == 1:
        # A single step only ever needs the first prepared output.
        model_mock.return_value = forward_returns[0]
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        # Multiple steps consume the prepared outputs in order.
        model_mock.side_effect = forward_returns
        model_mock.compute_logits.side_effect = logits_returns
    proposer.model = model_mock

    # Inputs for propose()
    cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
                                 dtype=torch.int32,
                                 device=device)
    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    positions_per_request = [
        torch.arange(seq_len_1, device=device),
        torch.arange(seq_len_2, device=device),
    ]
    target_positions = torch.cat(positions_per_request)
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    target_slot_mapping = torch.randint(0,
                                        100, (total_tokens, ),
                                        device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    block_table = torch.randint(0, 10, (batch_size, 10), device=device)
    sampling_metadata = mock.MagicMock()

    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              target_slot_mapping=target_slot_mapping,
                              next_token_ids=next_token_ids,
                              cu_num_tokens=cu_num_tokens,
                              block_table=block_table,
                              sampling_metadata=sampling_metadata)

    assert result.shape == (batch_size, num_speculative_tokens)

    # One row per request, e.g. [[42, 43, 44], [60, 61, 62]] for three
    # speculative tokens, or [[42], [60]] for one.
    expected_tokens = torch.tensor(
        [[base_id + offset for offset in range(num_speculative_tokens)]
         for base_id in base_token_ids],
        device=device)

    assert torch.equal(result, expected_tokens)