"examples/pooling/token_classify/ner_offline.py" did not exist on "5f696c33b1fbf33fe91ecdd958874b9dd52f79b4"
model_input_split.py 15.1 KB
Newer Older
lizhigong's avatar
lizhigong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import torch
from vllm.attention.backends.flashmla import FlashMLAMetadata
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import async_tensor_h2d

def cumsum(lst):
    """Return the running totals of ``lst`` prefixed with ``0``.

    The result has one more entry than the input: ``result[i]`` is the sum of
    the first ``i`` elements, so consecutive entries delimit each element's
    span in a flattened layout (e.g. ``[2, 3] -> [0, 2, 5]``).

    Args:
        lst: Iterable of addable values (typically ints).

    Returns:
        A new list ``[0, lst[0], lst[0] + lst[1], ...]``.
    """
    # Local import keeps this helper self-contained without touching the
    # module-level import block.
    from itertools import accumulate

    return list(accumulate(lst, initial=0))

def _split_request_ids_to_seq_ids(request_ids_to_seq_ids, batch_size_left):
    """Split a request-id mapping by insertion order: the first
    ``batch_size_left`` entries go left, the rest go right."""
    # NOTE(review): assumes one entry per sequence in the batch (i.e. one
    # sequence per request) so that entry order lines up with batch rows.
    items = list(request_ids_to_seq_ids.items())
    return dict(items[:batch_size_left]), dict(items[batch_size_left:])


def _make_half_attn_metadata(attn_metadata, rows, input_positions, common):
    """Build backend-specific attention metadata for one half of a batch.

    Args:
        attn_metadata: The full batch's metadata; must be a
            ``FlashMLAMetadata`` or ``ROCmFlashAttentionMetadata`` (the caller
            checks this).
        rows: Slice selecting this half's sequences from per-sequence lists.
        input_positions: This half's slice of the position tensor.
        common: Keyword arguments shared by both supported backends.
    """
    if isinstance(attn_metadata, FlashMLAMetadata):
        # NOTE(review): the context_chunk_* fields, decode_tile_scheduler_metadata
        # and decode_num_splits are passed through unsplit from the full batch.
        # If FlashMLA consumes them for this half, they likely need to be
        # recomputed per half — confirm.
        return FlashMLAMetadata(
            **common,
            multi_modal_placeholder_index_maps=attn_metadata.multi_modal_placeholder_index_maps,
            input_positions=input_positions,
            head_dim=attn_metadata.head_dim,
            is_profile_run=attn_metadata.is_profile_run,
            context_chunk_cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens,
            context_chunk_starts=attn_metadata.context_chunk_starts,
            context_chunk_seq_tot=attn_metadata.context_chunk_seq_tot,
            context_chunk_max_seq_lens=attn_metadata.context_chunk_max_seq_lens,
            context_chunk_workspace=attn_metadata.context_chunk_workspace,
            decode_tile_scheduler_metadata=attn_metadata.decode_tile_scheduler_metadata,
            decode_num_splits=attn_metadata.decode_num_splits,
        )
    return ROCmFlashAttentionMetadata(
        **common,
        multi_modal_placeholder_index_maps={},
        tree_attention_masks_tensor=None,
        block_tables_list=attn_metadata.block_tables_list[rows],
        # Encoder / cross-attention fields are not carried over to the halves.
        encoder_seq_lens=None,
        encoder_seq_lens_tensor=None,
        max_encoder_seq_len=None,
        num_encoder_tokens=None,
        cross_slot_mapping=None,
        cross_block_tables=None,
    )


def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
    """Split a batched model input into two independent sub-batches.

    The first ``batch_size_left`` sequences form the left half and the
    remaining ones the right half. Token-level tensors (input tokens,
    positions, slot mapping) are split by each half's query-token count.
    Sequence-level tensors (seq-lens tensor, block tables, context lens) are
    split by sequence count. Attention and sampling metadata are rebuilt for
    each half.

    Args:
        model_input: The ``ModelInputForGPUWithSamplingMetadata`` to split.
        self_device: Device on which newly created index tensors are placed.
        batch_size_left: Number of sequences in the left half.
        batch_size_right: Number of sequences in the right half. Together with
            ``batch_size_left`` it must cover the whole batch, otherwise
            ``torch.split`` raises.

    Returns:
        ``(model_input_left, model_input_right)``.

    Raises:
        NotImplementedError: If the attention metadata is neither
            ``FlashMLAMetadata`` nor ``ROCmFlashAttentionMetadata``.
        ValueError: If either half would contain no sequences.
    """
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

    attn_metadata = model_input.attn_metadata
    if not isinstance(attn_metadata, (FlashMLAMetadata, ROCmFlashAttentionMetadata)):
        raise NotImplementedError(
            "split_model_input only supports FlashMLAMetadata and "
            "ROCmFlashAttentionMetadata, got "
            f"{type(attn_metadata).__name__}")

    left_rows = slice(0, batch_size_left)
    right_rows = slice(batch_size_left, None)
    query_lens = model_input.query_lens
    if len(query_lens[left_rows]) == 0 or len(query_lens[right_rows]) == 0:
        raise ValueError(
            "split_model_input requires both halves to be non-empty "
            f"(batch_size_left={batch_size_left}, num_seqs={len(query_lens)})")

    batch_size_split = [batch_size_left, batch_size_right]
    query_tokens_split = [sum(query_lens[left_rows]), sum(query_lens[right_rows])]

    # Token-level tensors are split by query-token count, sequence-level
    # tensors by sequence count.
    input_tokens = torch.split(model_input.input_tokens, query_tokens_split, dim=0)
    input_positions = torch.split(model_input.input_positions, query_tokens_split, dim=0)
    slot_mapping = torch.split(attn_metadata.slot_mapping, query_tokens_split, dim=0)
    seq_lens_tensor = torch.split(attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
    block_tables = torch.split(attn_metadata.block_tables, batch_size_split, dim=0)
    context_lens_tensor = torch.split(attn_metadata.context_lens_tensor, batch_size_split, dim=0)

    request_ids_to_seq_ids = _split_request_ids_to_seq_ids(
        model_input.request_ids_to_seq_ids, batch_size_left)
    sampling_metadata = model_input.sampling_metadata
    zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)

    halves = []
    for side, rows in enumerate((left_rows, right_rows)):
        half_query_lens = query_lens[rows]
        half_seq_lens = attn_metadata.seq_lens[rows]
        max_seq_len = max(half_seq_lens)
        if model_input.is_prompt:
            num_prefills = batch_size_split[side]
            num_prefill_tokens = query_tokens_split[side]
            num_decode_tokens = 0
            max_prefill_seq_len, max_decode_seq_len = max_seq_len, 0
        else:
            num_prefills = 0
            num_prefill_tokens = 0
            # Counts one decode token per sequence (assumes single-token
            # decode queries).
            num_decode_tokens = batch_size_split[side]
            max_prefill_seq_len, max_decode_seq_len = 0, max_seq_len

        # Start offset of each sequence's query tokens within this half.
        query_start_loc_list = cumsum(half_query_lens)
        query_start_loc = async_tensor_h2d(query_start_loc_list, torch.int32,
                                           self_device, True)
        seq_start_loc = torch.cat(
            (zero_tensor, seq_lens_tensor[side].cumsum(dim=0)), dim=0).to(torch.int32)
        # Index of the last query token of every sequence in this half. It is
        # derived from query_lens (the tokens actually present in
        # input_tokens), not seq_lens, which also count already-cached context
        # and would point past the end of the hidden states during decode.
        # NOTE(review): assumes one sampled token per sequence (no prompt
        # logprobs, one sequence per group) — confirm for those modes.
        selected_token_indices = async_tensor_h2d(
            [start - 1 for start in query_start_loc_list[1:]], torch.long,
            self_device, True)

        common = dict(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping[side],
            enable_kv_scales_calculation=attn_metadata.enable_kv_scales_calculation,
            use_cuda_graph=attn_metadata.use_cuda_graph,
            seq_lens=half_seq_lens,
            seq_lens_tensor=seq_lens_tensor[side],
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            context_lens_tensor=context_lens_tensor[side],
            block_tables=block_tables[side],
            max_query_len=max(half_query_lens),
            max_decode_query_len=None,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
        )
        half_attn_metadata = _make_half_attn_metadata(
            attn_metadata, rows, input_positions[side], common)

        seq_groups = sampling_metadata.seq_groups
        if seq_groups is not None:
            seq_groups = seq_groups[rows]

        halves.append(ModelInputForGPUWithSamplingMetadata(
            input_tokens=input_tokens[side],
            input_positions=input_positions[side],
            token_types=None,
            seq_lens=half_seq_lens,
            query_lens=half_query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            attn_metadata=half_attn_metadata,
            prompt_adapter_mapping=model_input.prompt_adapter_mapping,
            prompt_adapter_requests=model_input.prompt_adapter_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids[side],
            finished_requests_ids=model_input.finished_requests_ids,
            virtual_engine=model_input.virtual_engine,
            async_callback=model_input.async_callback,
            scheduler_outputs=model_input.scheduler_outputs,
            previous_hidden_states=model_input.previous_hidden_states,
            sampling_metadata=SamplingMetadata(
                seq_groups=seq_groups,
                selected_token_indices=selected_token_indices,
                # NOTE(review): passed through unsplit from the full batch.
                categorized_sample_indices=sampling_metadata.categorized_sample_indices,
                num_prompts=num_prefills,
                skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
                reuse_sampling_tensors=sampling_metadata.reuse_sampling_tensors,
            ),
            is_prompt=model_input.is_prompt,
        ))

    return halves[0], halves[1]