two_batch_overlap.py 20.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465

import os
import queue
import threading
import torch
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.utils import async_tensor_h2d
from vllm.logger import init_logger
from vllm.profiler.prof import profile

def _env_flag(name):
    """Return True only when environment variable ``name`` is exactly '1'."""
    return os.environ.get(name) == '1'


# Master switch for two-batch overlap (TBO).
enable_tbo = _env_flag('VLLM_ENABLE_TBO')

# Also split decode batches; without it tbo_model_executable only splits prompt batches.
enable_tbo_decode = _env_flag('VLLM_TBO_DECODE')

# Run TBO all-reduces on the current stream instead of the side stream, skipping event sync.
tbo_one_stream = _env_flag('VLLM_TBO_ONE_STREAM')

logger = init_logger(__name__)

def is_enable_tbo():
    """Report whether two-batch overlap was switched on via VLLM_ENABLE_TBO=1."""
    return bool(enable_tbo)

class TwoBatchOverlap():
    """Coordinate two-batch overlap (TBO) between two worker threads.

    A batch is split into a "left" and a "right" half and each half runs
    ``model_executable`` on its own worker thread. The workers hand control
    back and forth through ``sem_left`` / ``sem_right`` every time they reach
    a TBO all-reduce (see ``tbo_thread_synchronize``). The all-reduces
    themselves are executed by whichever thread calls ``all_reduce``. That
    loop drains ``all_reduce_queue`` until the right worker posts the ``None``
    end-of-step sentinel.
    """

    def __init__(self):
        # Per-half model inputs consumed by the worker threads (None = exit).
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-half model outputs produced by the worker threads.
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # Pending all-reduce requests ([buf, event_c2t, event_t2c]; None = end
        # of step) and their results, in request order.
        self.all_reduce_queue = queue.Queue()
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread idents used by tbo_thread_synchronize to tell the workers apart.
        self.left_tid = 0
        self.right_tid = 0
        # Turn-taking semaphores: each worker blocks on its own until the
        # other worker releases it.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # While True, the left worker's next synchronize does not release
        # sem_right (set by the caller before a step starts).
        self.left_first = False
        # True while a TBO step is in flight; makes tbo_all_reduce route
        # through the queues above.
        self.tbo_running = False
        # Side stream the all-reduce runs on unless VLLM_TBO_ONE_STREAM=1.
        self.stream = torch.cuda.Stream()
        # c2t: recorded on a worker's stream before an all-reduce and awaited
        # by self.stream. t2c: recorded on self.stream after the all-reduce and
        # awaited by the worker's stream.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(q):
        """Remove and discard every item currently in ``q`` without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start the left and right worker threads on empty input queues."""
        # Queue.empty() only *reports* emptiness; drain explicitly so a stale
        # entry (e.g. an unconsumed None exit sentinel) cannot reach the new
        # workers.
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
        self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
        self.left_thread.start()
        self.right_thread.start()

    def finish_thread(self):
        """Send the None exit sentinel to each running worker and join it."""
        if self.left_thread is not None:
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop: run the model on each input taken from ``queue``.

        If ``queue`` is ``model_input_left_queue`` this thread becomes the
        left worker, otherwise the right worker. The loop exits on ``None``.
        The forward pass uses the context stored by ``set_model_input``.
        """
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = threading.get_ident()
            logger.info('tbo:new thread %d', self.left_tid)
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = threading.get_ident()
            logger.info('tbo:new thread %d', self.right_tid)
            init_tbo_forward_context(False, self.right_tid)
        while True:
            model_input = queue.get()
            if model_input is None:
                break
            profile.ProfRangePush('start')
            # Wait for this worker's first turn, without recording an event.
            self.tbo_thread_synchronize(False)
            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config, self.virtual_engine):
                hidden_or_intermediate_states = self.model_executable(
                    input_ids=model_input.input_tokens,
                    positions=model_input.input_positions,
                    intermediate_tensors=self.intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                 device=self.self_device),
                    **self.seqlen_agnostic_kwargs,
                    **self.model_kwargs,
                )
            profile.ProfRangePush('end')
            if is_left_thread:
                # Give the remaining turn to the right worker.
                self.sem_right.release()
                self.states_left_queue.put(hidden_or_intermediate_states)
            else:
                # End-of-step sentinel: stops the all_reduce() serving loop.
                self.all_reduce_queue.put(None)
                self.states_right_queue.put(hidden_or_intermediate_states)

    def tbo_thread_synchronize(self, recode_flag = True):
        """Yield the turn to the other worker and block until it is ours again.

        The left worker is identified by ``left_tid``; any other calling
        thread takes the right-worker path.

        Args:
            recode_flag: when True (and VLLM_TBO_ONE_STREAM is unset), record
                the caller's c2t event on the current stream before yielding.

        Returns:
            ``(event_c2t, event_t2c)`` belonging to the calling worker.
        """
        tid = threading.get_ident()
        if tid == self.left_tid:
            if recode_flag and not tbo_one_stream:
                logger.debug('tbo: left_c2t recorded')
                self.event_left_c2t.record()
            if not self.left_first:
                self.sem_right.release()
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        if recode_flag and not tbo_one_stream:
            logger.debug('tbo: right_c2t recorded')
            self.event_right_c2t.record()
        self.sem_left.release()
        profile.ProfRangePop()
        self.sem_right.acquire()
        profile.ProfRangePush('right')
        return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs):
        """Store the shared forward context and enqueue one input per worker.

        Starts the worker threads first if they are not running.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors = intermediate_tensors
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs = model_kwargs
        # Context is assigned before the puts, so workers see it when they wake.
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers have produced output; return ``(left, right)``."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Serve worker all-reduce requests until the end-of-step sentinel.

        Each request is ``[buf, event_c2t, event_t2c]``; results are put on
        ``all_reduce_out`` in request order. Unless VLLM_TBO_ONE_STREAM=1, the
        all-reduce runs on ``self.stream`` after waiting on ``event_c2t``, and
        ``event_t2c`` is recorded on that stream afterwards.
        """
        while True:
            request = self.all_reduce_queue.get()
            if request is None:
                break
            buf, event_c2t, event_t2c = request
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                with torch.cuda.stream(self.stream):
                    logger.debug('tbo: stream.wait_event event_c2t before all_reduce')
                    self.stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    logger.debug('tbo: event_t2c recorded')
                    event_t2c.record()
            self.all_reduce_out.put(output)

# Module-level TwoBatchOverlap instance. init_two_batch_overlap() creates it
# (only when VLLM_ENABLE_TBO=1); finish_two_batch_overlap() stops its worker
# threads and resets it to None.
tbo_obj = None

def init_two_batch_overlap():
    """Create the module-level TwoBatchOverlap instance on first use.

    A no-op when VLLM_ENABLE_TBO is not '1' or when the instance already
    exists.
    """
    global tbo_obj
    if not enable_tbo or tbo_obj is not None:
        return
    tbo_obj = TwoBatchOverlap()

def finish_two_batch_overlap():
    """Stop the TBO worker threads, if any, and drop the module-level instance."""
    global tbo_obj
    instance = tbo_obj
    if instance is None:
        return
    instance.finish_thread()
    tbo_obj = None

def tbo_all_reduce(obj):
    """All-reduce ``obj`` across the tensor-parallel group, TBO-aware.

    Outside a TBO step this is simply ``tensor_model_parallel_all_reduce(obj)``.
    While a TBO step is running (expected to be called from a TBO worker
    thread), it yields the turn to the other worker. It then hands the buffer
    to the thread serving ``TwoBatchOverlap.all_reduce`` and blocks for the
    result. Unless VLLM_TBO_ONE_STREAM=1, the current stream is made to wait
    on the all-reduce's t2c event before the result is returned.
    """
    if enable_tbo and tbo_obj is not None and tbo_obj.tbo_running:
        event_c2t, event_t2c = tbo_obj.tbo_thread_synchronize()
        tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
        output = tbo_obj.all_reduce_out.get()
        if not tbo_one_stream:
            # Debug output goes to the logger, not stdout: this runs on every
            # all-reduce of every layer.
            logger.debug('tbo: current_stream wait event event_t2c')
            torch.cuda.current_stream().wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)

def cumsum(lst):
    """Return the running totals of ``lst``, prefixed with 0.

    The result has ``len(lst) + 1`` entries: the start offset of every
    element followed by the grand total, e.g. ``cumsum([3, 1, 2]) ==
    [0, 3, 4, 6]`` and ``cumsum([]) == [0]``. Any iterable is accepted.
    """
    import itertools

    return list(itertools.accumulate(lst, initial=0))

def _build_half_model_input(model_input, self_device, start, end):
    """Build the model input covering sequences ``[start, end)`` of ``model_input``.

    Assumes sequence ``i`` owns ``model_input.query_lens[i]`` consecutive rows
    of the flattened ``input_tokens`` / ``input_positions`` / ``slot_mapping``,
    in sequence order. The batch is treated as all-prefill or all-decode
    according to ``model_input.is_prompt``.
    """
    # Imported locally (presumably to avoid an import cycle with
    # vllm.worker.model_runner).
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

    attn_metadata = model_input.attn_metadata
    sampling_metadata = model_input.sampling_metadata
    num_seqs = end - start
    query_lens = model_input.query_lens[start:end]
    seq_lens = attn_metadata.seq_lens[start:end]
    # Row range of this half inside the flattened per-token tensors.
    token_start = sum(model_input.query_lens[:start])
    token_end = token_start + sum(query_lens)

    num_prefills = 0
    num_prefill_tokens = 0
    num_decode_tokens = 0
    max_prefill_seq_len = 0
    max_decode_seq_len = 0
    if model_input.is_prompt:
        num_prefills = num_seqs
        num_prefill_tokens = sum(query_lens)
        max_prefill_seq_len = max(seq_lens)
    else:
        num_decode_tokens = num_seqs
        max_decode_seq_len = max(seq_lens)

    seq_lens_tensor = attn_metadata.seq_lens_tensor[start:end]
    zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)
    query_start_loc = async_tensor_h2d(cumsum(query_lens), torch.int32,
                                       self_device,
                                       True)
    seq_start_loc = torch.cat((zero_tensor, seq_lens_tensor.cumsum(dim=0)), dim=0).to(torch.int32)

    seq_groups = None
    if sampling_metadata.seq_groups is not None:
        seq_groups = sampling_metadata.seq_groups[start:end]
    # The output of this half presumably has one row per query token, so the
    # last token of sequence i sits at query_start_loc[i + 1] - 1. Deriving it
    # from seq_lens is only right when seq_len == query_len; it is wrong for
    # decode and prefix-cached prefill.
    selected_token_indices = query_start_loc[1:].to(torch.long) - 1

    half_attn_metadata = ROCmFlashAttentionMetadata(
        seq_lens_tensor=seq_lens_tensor,
        max_decode_seq_len=max_decode_seq_len,
        block_tables=attn_metadata.block_tables[start:end],
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        slot_mapping=attn_metadata.slot_mapping[token_start:token_end],
        multi_modal_placeholder_index_maps={},
        enable_kv_scales_calculation=attn_metadata.enable_kv_scales_calculation,
        seq_lens=seq_lens,
        max_prefill_seq_len=max_prefill_seq_len,
        use_cuda_graph=attn_metadata.use_cuda_graph,
        max_query_len=max(query_lens),
        query_start_loc=query_start_loc,
        seq_start_loc=seq_start_loc,
        context_lens_tensor=attn_metadata.context_lens_tensor[start:end],
        max_decode_query_len=None,
        _cached_prefill_metadata=None,
        _cached_decode_metadata=None,
        tree_attention_masks_tensor=None,
        block_tables_list=attn_metadata.block_tables_list[start:end],
        encoder_seq_lens=None,
        encoder_seq_lens_tensor=None,
        max_encoder_seq_len=None,
        num_encoder_tokens=None,
        cross_slot_mapping=None,
        cross_block_tables=None,
    )
    return ModelInputForGPUWithSamplingMetadata(
        input_tokens=model_input.input_tokens[token_start:token_end],
        input_positions=model_input.input_positions[token_start:token_end],
        token_types=None,
        seq_lens=seq_lens,
        query_lens=query_lens,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        attn_metadata=half_attn_metadata,
        prompt_adapter_mapping=model_input.prompt_adapter_mapping,
        prompt_adapter_requests=model_input.prompt_adapter_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        # Requests are assigned to halves by position, in dict order.
        request_ids_to_seq_ids=dict(list(model_input.request_ids_to_seq_ids.items())[start:end]),
        finished_requests_ids=model_input.finished_requests_ids,
        virtual_engine=model_input.virtual_engine,
        async_callback=model_input.async_callback,
        scheduler_outputs=model_input.scheduler_outputs,
        previous_hidden_states=model_input.previous_hidden_states,
        sampling_metadata=SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=sampling_metadata.categorized_sample_indices,
            num_prompts=num_prefills,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
            reuse_sampling_tensors=sampling_metadata.reuse_sampling_tensors,
        ),
        is_prompt=model_input.is_prompt,
    )


def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
    """Split ``model_input`` into a left and a right half for two-batch overlap.

    Args:
        model_input: the full-batch model input; its ``attn_metadata`` is
            expected to be ROCmFlashAttentionMetadata.
        self_device: device for the per-half tensors created here.
        batch_size_left: number of leading sequences in the left half.
        batch_size_right: number of remaining sequences in the right half.

    Returns:
        ``(model_input_left, model_input_right)``.

    Raises:
        ValueError: if either half would be empty or the two sizes do not add
            up to ``len(model_input.attn_metadata.seq_lens)``.
    """
    batch_size = len(model_input.attn_metadata.seq_lens)
    if (batch_size_left < 1 or batch_size_right < 1
            or batch_size_left + batch_size_right != batch_size):
        raise ValueError(
            f'cannot split a batch of {batch_size} sequences into '
            f'{batch_size_left} + {batch_size_right}')
    model_input_left = _build_half_model_input(model_input, self_device, 0, batch_size_left)
    model_input_right = _build_half_model_input(model_input, self_device, batch_size_left, batch_size)
    return model_input_left, model_input_right

def merge_model_output(states_left, states_right):
    """Stack the left-half and right-half outputs back into one batch along dim 0."""
    return torch.cat((states_left, states_right), dim=0)

def tbo_model_executable(
        model_input,
        vllm_config,
        virtual_engine,
        model_executable,
        intermediate_tensors,
        multi_modal_kwargs,
        self_device,
        seqlen_agnostic_kwargs,
        model_kwargs,
    ):
    """Run ``model_executable`` on ``model_input``, split in two when TBO applies.

    It falls back to one plain forward pass in these cases:
      - TBO is not initialized (VLLM_ENABLE_TBO unset).
      - The batch has fewer than two sequences.
      - It is a decode batch and VLLM_TBO_DECODE is unset.
      - The attention metadata is not ROCmFlashAttentionMetadata.
      - It is a CUDA-graph decode.

    Otherwise the batch is halved (the right half takes the extra sequence
    when the size is odd). Both halves run on the TBO worker threads while
    this thread serves their all-reduces. The two outputs are then
    concatenated along dim 0.

    Returns:
        The hidden (or intermediate) states for the whole batch.
    """
    init_two_batch_overlap()
    attn_metadata = model_input.attn_metadata
    is_rocm_fa = isinstance(attn_metadata, ROCmFlashAttentionMetadata)
    is_cuda_graph_decode = attn_metadata.use_cuda_graph and not model_input.is_prompt
    batch_size = len(attn_metadata.seq_lens)
    use_tbo = (tbo_obj is not None
               and batch_size > 1
               and (model_input.is_prompt or enable_tbo_decode)
               and is_rocm_fa
               and not is_cuda_graph_decode)
    if not use_tbo:
        with set_forward_context(attn_metadata, vllm_config, virtual_engine):
            return model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )

    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True
    try:
        model_input_left, model_input_right = split_model_input(
            model_input, self_device, batch_size_left, batch_size_right)
        tbo_obj.set_model_input(model_input_left,
                                model_input_right,
                                vllm_config,
                                virtual_engine,
                                model_executable,
                                intermediate_tensors,
                                multi_modal_kwargs,
                                self_device,
                                seqlen_agnostic_kwargs,
                                model_kwargs)
        # Serve the workers' all-reduces until the right worker signals the end.
        tbo_obj.all_reduce()
        states_left, states_right = tbo_obj.get_model_output()
        return merge_model_output(states_left, states_right)
    finally:
        # Reset even on error; otherwise later tbo_all_reduce calls would keep
        # routing through the worker queues.
        tbo_obj.tbo_running = False