"vscode:/vscode.git/clone" did not exist on "e8eb0490ce098b1add05877363b185f3a7b570c5"
gpu_model_runner.py 7.08 KB
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import torch
import numpy as np
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import get_tp_group
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile


class V1ZeroModelRunner:
    """Shared bookkeeping for the zero-overhead sampling helpers in this module.

    Holds the token ids sampled by the most recent step plus what is needed to
    patch the host-side copies of those tokens once the real values have been
    copied back from the device.
    """

    def __init__(self):
        # Tensor of ids sampled in the most recent step (None before the first step).
        self.last_sampled_token_ids = None
        # Request ids tracked for that step, and for each one the index in its
        # output_token_ids where the step's token was appended.
        self.last_sampled_req_ids = []
        self.last_sampled_token_lens = []
        # Host copy of the sampled ids (possibly still in flight) and the
        # event that is synchronized on before reading it.
        self.last_sampler_host_tokens = None
        self.last_sampler_event = torch.cuda.Event(enable_timing=False)
        # [row, start, end] spans of input_batch.token_ids_cpu written during
        # the last step, used when patching in the real token values.
        self.token_ids_cpu_fix_recode = []

    def set_last_sampled_token_ids(self, sampled_token_ids):
        """Record ``sampled_token_ids`` as the latest step and start fresh tracking lists.

        The tracking lists are replaced with new list objects rather than
        cleared in place, so any caller still holding the previous lists keeps
        their contents intact.
        """
        self.last_sampled_req_ids, self.last_sampled_token_lens = [], []
        self.last_sampled_token_ids = sampled_token_ids

# Module-level singleton holding the state shared by zero_prepare_inputs() and
# execute_model_sampled(). Note: constructing it creates a torch.cuda.Event at
# import time.
v1_zero_overhead = V1ZeroModelRunner()

def zero_prepare_inputs(runner, scheduler_output, input_ids):
    """Copy last step's sampled tokens (still on the device) into ``input_ids``.

    In the zero-overhead path the host-side token buffers only hold
    placeholders for the previous step's samples. For every request in the
    current batch that was tracked in that step, this writes its real sampled
    token(s) from ``v1_zero_overhead.last_sampled_token_ids`` into the first
    scheduled position(s) of that request in ``input_ids``, in place.

    Args:
        runner: model runner; ``input_batch.req_ids`` and ``device`` are read.
        scheduler_output: ``num_scheduled_tokens[req_id]`` is read for every
            request in the batch.
        input_ids: flat tensor of this step's input ids, laid out request by
            request in ``runner.input_batch.req_ids`` order; modified in place.
    """
    state = v1_zero_overhead
    last_tokens = state.last_sampled_token_ids
    if last_tokens is None:
        return
    sampled_tokens_num = last_tokens.shape[1]

    # Map each tracked request to its row in ``last_tokens``. Requests whose
    # sample was discarded are not tracked, so a request's position in
    # ``last_sampled_req_ids`` is not its batch row in general; prefer the
    # rows recorded alongside the ids and fall back to list position only
    # when they are missing or out of sync.
    tracked_ids = state.last_sampled_req_ids
    rows = getattr(state, "last_sampled_req_rows", None)
    if rows is None or len(rows) != len(tracked_ids):
        rows = range(len(tracked_ids))
    row_of = {}
    for req_id, row in zip(tracked_ids, rows):
        row_of.setdefault(req_id, row)  # keep first occurrence, like list.index

    src_offsets = []  # flat offsets into last_tokens
    dst_offsets = []  # flat offsets into input_ids
    token_offset = 0
    num_scheduled = scheduler_output.num_scheduled_tokens
    for req_id in runner.input_batch.req_ids:
        row = row_of.get(req_id)
        if row is not None:
            src_offsets.append(row * sampled_tokens_num)
            dst_offsets.append(token_offset)
        token_offset += num_scheduled[req_id]
    if not src_offsets:
        return

    src = async_tensor_h2d(src_offsets, torch.int32, runner.device, True)
    dst = async_tensor_h2d(dst_offsets, torch.int32, runner.device, True)
    flat_tokens = last_tokens.flatten()
    for i in range(sampled_tokens_num):
        input_ids[dst + i] = flat_tokens[src + i]


def _apply_pending_fixup(runner):
    """Patch the placeholder tokens written by the previous zero-overhead step.

    Waits on ``last_sampler_event`` for the host copy of the previous step's
    sampled tokens, then overwrites the placeholder entries in
    ``runner.input_batch.token_ids_cpu`` and in each tracked request's
    ``output_token_ids``. The pending host copy is consumed, so a second call
    without a new zero-overhead step in between is a no-op.

    Returns:
        ``(fix_req_ids, fix_sampled_token_ids)``: the request ids tracked by
        the previous step and the host copy of its sampled tokens as nested
        lists indexed by that step's batch row, or ``(None, None)`` when no
        copy is pending.
    """
    state = v1_zero_overhead
    host_tokens = state.last_sampler_host_tokens
    if host_tokens is None:
        return None, None
    state.last_sampler_event.synchronize()
    fix_sampled_token_ids = host_tokens.tolist()
    # Consume the pending copy so it can never be re-applied to spans recorded
    # by a later step.
    state.last_sampler_host_tokens = None

    input_batch = runner.input_batch
    spans = state.token_ids_cpu_fix_recode
    span_req_ids = getattr(state, "token_ids_cpu_fix_req_ids", None)
    if span_req_ids is None or len(span_req_ids) != len(spans):
        span_req_ids = [None] * len(spans)
    for (prev_row, start_idx, end_idx), req_id in zip(spans, span_req_ids):
        row = prev_row
        if req_id is not None:
            # The batch may have been condensed/reordered since the span was
            # recorded; locate the request's current row by id.
            row = input_batch.req_id_to_index.get(req_id)
            if row is None:
                continue  # Request is no longer in the batch.
        input_batch.token_ids_cpu[row, start_idx:end_idx] = fix_sampled_token_ids[prev_row]

    fix_req_ids = state.last_sampled_req_ids
    # Tracked ids skip discarded rows, so index the host tokens by the
    # recorded batch row rather than by position in fix_req_ids.
    rows = getattr(state, "last_sampled_req_rows", None)
    if rows is None or len(rows) != len(fix_req_ids):
        rows = range(len(fix_req_ids))
    for req_id, prev_row, token_idx in zip(fix_req_ids, rows, state.last_sampled_token_lens):
        if req_id in runner.requests:
            req_state = runner.requests[req_id]
            req_state.output_token_ids[token_idx] = fix_sampled_token_ids[prev_row][0]
    return fix_req_ids, fix_sampled_token_ids


def execute_model_sampled(runner, max_gen_len, sampled_token_ids,
                          discard_sampled_tokens_req_indices, scheduler_output,
                          sampling_metadata,
                          hidden_states,
                          sample_hidden_states,
                          aux_hidden_states,
                          spec_decode_metadata,
                          attn_metadata,
                          logprobs_lists,
                          prompt_logprobs_dict,
                          finished_sending,
                          finished_recving,
                          num_nans_in_logits
                          ):
    """Record a step's sampled tokens and build the runner output.

    When ``max_gen_len == 1`` (no spec decode tokens) the sampled ids stay on
    the device: the host-side buffers receive all-ones placeholder tokens, an
    asynchronous copy to the host is started, and the placeholders are patched
    with the real ids at the start of the next call. Otherwise the rejection
    sampler's parsed output is recorded directly. Any pending patch from a
    previous zero-overhead step is applied first in either case.

    Args:
        runner: model runner whose ``input_batch`` and ``requests`` are updated.
        max_gen_len: 1 when the step has no spec decode tokens; larger otherwise.
        sampled_token_ids: sampled ids with one row per batch row (presumably
            a GPU tensor).
        discard_sampled_tokens_req_indices: batch rows whose sample is dropped.
        scheduler_output, sampling_metadata, hidden_states,
        sample_hidden_states, aux_hidden_states, spec_decode_metadata,
        attn_metadata: forwarded to ``runner.propose_draft_token_ids``.
        logprobs_lists, prompt_logprobs_dict, finished_sending,
        finished_recving, num_nans_in_logits: forwarded into the output.

    Returns:
        ZeroV1ModelRunnerOutput. ``fix_req_ids`` / ``fix_sampled_token_ids``
        carry the previous zero-overhead step's real tokens
        (``fix_sampled_token_ids`` is indexed by that step's batch row) so the
        receiver can patch its own placeholders; ``fix_sampled_token_ids`` is
        None when nothing was pending.
    """
    state = v1_zero_overhead
    input_batch = runner.input_batch

    # Patch the previous zero-overhead step's placeholders before anything
    # else, in both branches, so they are never left behind.
    fix_req_ids, fix_sampled_token_ids = _apply_pending_fixup(runner)

    zero_overhead = max_gen_len == 1
    if zero_overhead:
        # No spec decode tokens.
        if fix_req_ids is None:
            # Zero-overhead steps always report a list of ids.
            fix_req_ids = state.last_sampled_req_ids
        state.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
        state.last_sampler_event.record()
        # Rebinds (does not clear) the tracking lists, so fix_req_ids still
        # refers to the previous step's ids.
        state.set_last_sampled_token_ids(sampled_token_ids)
        # Placeholder tokens; the real ids are patched in on the next call.
        valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
    else:
        # Includes spec decode tokens. Drop the device-side tokens of any
        # earlier zero-overhead step so zero_prepare_inputs() does not replay
        # them into the next step's input ids.
        state.set_last_sampled_token_ids(None)
        valid_sampled_token_ids = runner.rejection_sampler.parse_output(
            sampled_token_ids,
            input_batch.vocab_size,
        )
    # Batch row of each entry in last_sampled_req_ids (which skips rows).
    state.last_sampled_req_rows = []

    # Mask out the sampled tokens that should not be sampled.
    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[i].clear()

    # Cache the sampled tokens in the model runner, so that the scheduler
    # doesn't need to send them back.
    # NOTE(woosuk): As an exception, when using PP, the scheduler sends
    # the sampled tokens back, because there's no direct communication
    # between the first-stage worker and the last-stage worker.
    state.token_ids_cpu_fix_recode.clear()
    span_req_ids = []  # request id for each token_ids_cpu_fix_recode entry
    for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
        if not sampled_ids:
            continue

        start_idx = input_batch.num_tokens_no_spec[req_idx]
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= runner.max_model_len, (
            "Sampled token IDs exceed the max model length. "
            f"Total number of tokens: {end_idx} > max_model_len: "
            f"{runner.max_model_len}")

        input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
        state.token_ids_cpu_fix_recode.append([req_idx, start_idx, end_idx])
        input_batch.num_tokens_no_spec[req_idx] = end_idx
        input_batch.num_tokens[req_idx] = end_idx
        req_id = input_batch.req_ids[req_idx]
        span_req_ids.append(req_id)
        if req_id in runner.requests:
            req_state = runner.requests[req_id]
            if zero_overhead:
                # Only placeholder tokens need to be tracked for patching.
                state.last_sampled_req_ids.append(req_id)
                state.last_sampled_req_rows.append(req_idx)
                state.last_sampled_token_lens.append(len(req_state.output_token_ids))
            req_state.output_token_ids.extend(sampled_ids)
    state.token_ids_cpu_fix_req_ids = span_req_ids

    if not runner.speculative_config:
        # Speculative decoding is not enabled.
        spec_token_ids = None
    else:
        # NOTE(review): in a zero-overhead step valid_sampled_token_ids holds
        # the all-ones placeholders, so drafts are proposed from them; confirm
        # the drafter tolerates this.
        spec_token_ids = runner.propose_draft_token_ids(
            scheduler_output,
            valid_sampled_token_ids,
            sampling_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
            spec_decode_metadata,
            attn_metadata,
        )

    # Clear KVConnector state after all KVs are generated.
    if has_kv_transfer_group():
        get_kv_transfer_group().clear_connector_metadata()

    runner.eplb_step()

    return ZeroV1ModelRunnerOutput(
        req_ids=input_batch.req_ids,
        req_id_to_index=input_batch.req_id_to_index,
        sampled_token_ids=valid_sampled_token_ids,
        spec_token_ids=spec_token_ids,
        logprobs=logprobs_lists,
        prompt_logprobs_dict=prompt_logprobs_dict,
        pooler_output=[],
        finished_sending=finished_sending,
        finished_recving=finished_recving,
        num_nans_in_logits=num_nans_in_logits,
        fix_req_ids=fix_req_ids,
        fix_sampled_token_ids=fix_sampled_token_ids,
    )