two_batch_overlap.py 21.1 KB
Newer Older
1
import gc
2
3
4
import os
import queue
import threading
5
from typing import List, Optional, Tuple
6
import torch
7
from vllm.attention.backends.abstract import AttentionMetadata
8
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
9
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
10
11
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
12
from vllm.sequence import IntermediateTensors
13
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
14
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_capture_attention_metadata, split_model_input
15
16
from vllm.logger import init_logger
from vllm.profiler.prof import profile
17
from vllm import envs
18
from vllm.utils import weak_ref_tensor
19
20
21
22
23

# With VLLM_TBO_ONE_STREAM=1, tbo_all_reduce runs the all-reduce on the calling
# worker's current stream instead of on the dedicated all_reduce_stream.
tbo_one_stream = os.getenv('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

# CUDA streams shared by all TwoBatchOverlap users. They start as None and are
# created by TwoBatchOverlap.__init__; _run_once temporarily points
# tbo_step_stream at the CUDA-graph capture stream.
tbo_step_stream = None
all_reduce_stream = None


class TwoBatchOverlap():
    """Execute one model forward as two half-batches on two worker threads.

    A "left" and a "right" worker thread each call ``model_executable`` on
    one half of the batch. The workers hand control to each other through two
    semaphores (see ``tbo_thread_synchronize``); ``tbo_all_reduce`` performs
    that hand-off right after issuing an all-reduce. Inputs are delivered
    through ``model_input_left_queue`` / ``model_input_right_queue`` and the
    results come back through ``states_left_queue`` / ``states_right_queue``
    (read by ``get_model_output``).
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream

        # Per-worker inputs (None during CUDA-graph capture) and outputs.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # threading.get_ident() of each worker; 0 means "not recorded yet".
        self.left_tid = 0
        self.right_tid = 0
        # Hand-off semaphores: a worker blocks on its own semaphore until the
        # peer releases it.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # Set to True by callers before each step; cleared by the left
        # worker's first hand-off so it waits for the right worker to arrive.
        self.left_first = False
        # Read by tbo_all_reduce to decide whether to take the TBO path.
        self.tbo_running = False
        # Selects the CUDA-graph-capture input path in the worker.
        self.tbo_in_capture = False

        # Check each stream on its own: _run_once points tbo_step_stream at the
        # capture stream before this constructor may run, and a single joint
        # check would then leave all_reduce_stream as None.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
        if all_reduce_stream is None:
            all_reduce_stream = torch.cuda.Stream()
        # Orders work between the caller's stream and tbo_step_stream.
        self.step_event = torch.cuda.Event(enable_timing=False)
        # Per-worker events used by tbo_all_reduce: *_c2t is recorded on the
        # worker's stream before the all-reduce, *_t2c on all_reduce_stream
        # after it.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(q):
        """Discard every item currently in ``q`` without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Reset per-step worker state and start the left and right workers."""
        # Queue.empty() only reports emptiness; drain explicitly so a stale
        # input left behind by an aborted step cannot be picked up.
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        # Thread identifiers can be recycled after a thread exits, so a stale
        # left_tid could equal the new right worker's ident before the new
        # left worker records its own. Clear both before starting workers.
        self.left_tid = 0
        self.right_tid = 0
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.right_thread.start()
        # Runs once per step, so keep it out of the info-level log.
        logger.debug('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and clear their handles."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker body: run the model on one half-batch and publish the result.

        Args:
            queue: the input queue this worker reads from. It is compared with
                ``model_input_left_queue`` to decide whether this is the left
                or the right worker. (Shadows the ``queue`` module here.)
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)

        with torch.cuda.stream(tbo_step_stream):
            model_input = queue.get()
            profile.ProfRangePush('start')
            # Wait for this worker's turn; the left worker goes first.
            self.tbo_thread_synchronize(tid)
            if is_left_thread:
                model_kwargs = self.model_kwargs_left
                intermediate_tensors = self.intermediate_tensors_left
            else:
                model_kwargs = self.model_kwargs_right
                intermediate_tensors = self.intermediate_tensors_right

            hidden_or_intermediate_states = None
            if self.tbo_in_capture:
                # Capture inputs were stashed by set_capture_model_input.
                if is_left_thread:
                    attn_metadata = self.attn_metadata_left
                    input_tokens = self.input_tokens_left
                    input_positions = self.split_input_positions[0]
                else:
                    attn_metadata = self.attn_metadata_right
                    input_tokens = self.input_tokens_right
                    input_positions = self.split_input_positions[1]
                with set_forward_context(attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=input_tokens,
                        positions=input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **model_kwargs,
                    )
            elif model_input is not None:
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **model_kwargs,
                    )

            if is_left_thread:
                # The left worker is finished: release the right worker so it
                # can complete without waiting for another hand-off.
                self.sem_right.release()
                self.states_left_queue.put(hidden_or_intermediate_states)
            else:
                self.states_right_queue.put(hidden_or_intermediate_states)
            profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand control from the calling worker to its peer and wait to resume.

        Releases the peer's semaphore, closes the current profiler range,
        blocks on the caller's own semaphore, then opens a 'left'/'right'
        range. On the left worker's first call of a step (``left_first``) the
        peer is not released, so the left worker waits for the right worker.

        Args:
            tid: ``threading.get_ident()`` of the calling worker.

        Returns:
            The caller's ``(c2t, t2c)`` event pair.
        """
        if tid == self.left_tid:
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            return self.event_left_c2t, self.event_left_t2c
        self.sem_left.release()
        profile.ProfRangePop()
        self.sem_right.acquire()
        profile.ProfRangePush('right')
        return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors_left,
                        intermediate_tensors_right,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs_left,
                        model_kwargs_right):
        """Store the shared forward arguments, then hand each worker its input.

        The attributes are assigned before the queue puts so that a worker
        woken by ``queue.get()`` sees them.
        """
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def set_capture_model_input(self,
                                input_tokens_left,
                                input_tokens_right,
                                split_input_positions,
                                vllm_config,
                                virtual_engine,
                                runner_model,
                                runner_device,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                model_kwargs_left,
                                model_kwargs_right,
                                attn_metadata_left,
                                attn_metadata_right):
        """Store pre-split CUDA-graph-capture inputs and wake both workers.

        Each worker receives ``None`` from its queue; with ``tbo_in_capture``
        set, it reads its inputs from the attributes assigned here.
        """
        self.input_tokens_left = input_tokens_left
        self.input_tokens_right = input_tokens_right
        self.split_input_positions = split_input_positions
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = runner_model
        self.self_device = runner_device
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both workers have published; return ``(left, right)``."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

# Process-wide TwoBatchOverlap instance, created on first use.
tbo_obj = None


def init_two_batch_overlap():
    """Ensure the shared TwoBatchOverlap exists, then start its worker threads."""
    global tbo_obj
    instance = tbo_obj
    if instance is None:
        instance = tbo_obj = TwoBatchOverlap()
    instance.init_tbo_thread()


def tbo_all_reduce(obj):
    """Tensor-parallel all-reduce of ``obj`` that cooperates with TBO workers.

    When TBO is disabled, no ``tbo_obj`` exists yet, or no TBO step is running,
    this is a plain ``tensor_model_parallel_all_reduce``. Inside a step, the
    calling worker hands control to its peer with ``tbo_thread_synchronize``
    after issuing the all-reduce (presumably so the peer's work overlaps the
    reduction). With ``VLLM_TBO_ONE_STREAM=1`` the reduction runs on the
    caller's current stream; otherwise it runs on ``all_reduce_stream`` and is
    ordered against the worker's stream with that worker's c2t/t2c events.
    """
    in_tbo_step = (envs.VLLM_ENABLE_TBO and tbo_obj is not None
                   and tbo_obj.tbo_running)
    if not in_tbo_step:
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    if tbo_one_stream:
        reduced = tensor_model_parallel_all_reduce(obj)
        tbo_obj.tbo_thread_synchronize(tid)
        return reduced

    if tid == tbo_obj.left_tid:
        before_reduce = tbo_obj.event_left_c2t
        after_reduce = tbo_obj.event_left_t2c
    else:
        before_reduce = tbo_obj.event_right_c2t
        after_reduce = tbo_obj.event_right_t2c

    # The reduction on all_reduce_stream waits for work already queued on the
    # worker's current stream.
    before_reduce.record()
    with torch.cuda.stream(all_reduce_stream):
        all_reduce_stream.wait_event(before_reduce)
        reduced = tensor_model_parallel_all_reduce(obj)
        after_reduce.record()
    # Yield to the peer worker, then make later work on tbo_step_stream wait
    # for the reduced result.
    tbo_obj.tbo_thread_synchronize(tid)
    tbo_step_stream.wait_event(after_reduce)
    return reduced

def merge_model_output(states_left, states_right):
    """Concatenate the left and right half-batch outputs along dim 0.

    IntermediateTensors are merged entry by entry, using the keys of
    ``states_left``; anything else is concatenated directly with
    ``torch.concat``.
    """
    if not isinstance(states_left, IntermediateTensors):
        return torch.concat([states_left, states_right], dim=0)
    merged = {
        name: torch.concat([left_tensor, states_right.tensors[name]], dim=0)
        for name, left_tensor in states_left.tensors.items()
    }
    return IntermediateTensors(merged)

def tbo_model_executable(
        model_input,
        vllm_config,
        virtual_engine,
        model_executable,
        intermediate_tensors,
        multi_modal_kwargs,
        self_device,
        seqlen_agnostic_kwargs,
        model_kwargs,
    ):
    """Run one forward pass, splitting the batch in two for TBO when possible.

    Falls back to a single ``model_executable`` call when the attention
    metadata type is not supported, the batch holds fewer than two sequences,
    or, for decode, ``envs.VLLM_TBO_DECODE_BS`` is below 2, the batch is
    smaller than it, or the metadata uses a CUDA graph.

    Otherwise the batch is split into a left half of ``batch_size // 2``
    sequences and a right half holding the rest. ``model_kwargs
    ['previous_hidden_states']`` and ``intermediate_tensors`` are split by the
    token counts from ``model_input.query_lens``. The halves run on the TBO
    worker threads (on ``tbo_step_stream``), and their outputs are joined by
    ``merge_model_output``.

    Returns:
        The output of ``model_executable`` for the whole batch (in the TBO path,
        the concatenation of both halves' outputs).
    """
    attn_metadata = model_input.attn_metadata
    is_support = is_supported_attention_metadata(attn_metadata)
    if not is_support:
        # Use a placeholder: passing the type as a bare extra argument made the
        # logging module report a formatting error instead of this message.
        logger.info("tbo:not support yet %s", type(attn_metadata))

    use_tbo = False
    if is_support:
        # Only read seq_lens / use_cuda_graph once the metadata type is known
        # to be supported.
        batch_size = len(attn_metadata.seq_lens)
        is_decode_tbo_invalid = not model_input.is_prompt and (
            envs.VLLM_TBO_DECODE_BS < 2
            or batch_size < envs.VLLM_TBO_DECODE_BS
            or attn_metadata.use_cuda_graph)
        # Each half needs at least one sequence.
        use_tbo = batch_size >= 2 and not is_decode_tbo_invalid

    if not use_tbo:
        with set_forward_context(attn_metadata, vllm_config, virtual_engine):
            return model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )

    profile.ProfRangePush('tbo_model_executable')
    init_two_batch_overlap()
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True
    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left

    model_input_left, model_input_right = split_model_input(
        model_input, self_device, batch_size_left, batch_size_right)

    model_kwargs_left = model_kwargs.copy()
    model_kwargs_right = model_kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None
    has_prev_hidden = "previous_hidden_states" in model_kwargs
    if has_prev_hidden or intermediate_tensors is not None:
        # Tokens owned by each half: sums of the per-sequence query lengths.
        query_tokens_split = [
            sum(model_input.query_lens[:batch_size_left]),
            sum(model_input.query_lens[batch_size_left:]),
        ]
        if has_prev_hidden:
            (model_kwargs_left["previous_hidden_states"],
             model_kwargs_right["previous_hidden_states"]) = torch.split(
                 model_kwargs["previous_hidden_states"],
                 query_tokens_split, dim=0)
        if intermediate_tensors is not None:
            left_tensors = {}
            right_tensors = {}
            for key, tensor in intermediate_tensors.tensors.items():
                left_tensors[key], right_tensors[key] = torch.split(
                    tensor, query_tokens_split, dim=0)
            intermediate_tensors_left = IntermediateTensors(left_tensors)
            intermediate_tensors_right = IntermediateTensors(right_tensors)

    # tbo_step_stream must not start before work already queued on the
    # caller's stream; the caller's stream in turn waits for the TBO result.
    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        tbo_obj.set_model_input(model_input_left,
                                model_input_right,
                                vllm_config,
                                virtual_engine,
                                model_executable,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                multi_modal_kwargs,
                                self_device,
                                seqlen_agnostic_kwargs,
                                model_kwargs_left,
                                model_kwargs_right)
        states_left, states_right = tbo_obj.get_model_output()
        hidden_or_intermediate_states = merge_model_output(states_left,
                                                           states_right)
        tbo_obj.tbo_running = False
        tbo_obj.step_event.record()
        tbo_obj.finish_thread()
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states

def _run_once(vllm_config, virtual_engine,
        runner,
        self_device,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_inputs: Optional[IntermediateTensors],
        attn_metadata: AttentionMetadata,
        stream: torch.cuda.Stream,
        **kwargs):
    """Run one two-batch-overlap forward for CUDA-graph warm-up or capture.

    Splits ``input_ids`` and ``positions`` along dim 0 into a left part of
    ``input_ids.shape[0] // 2`` rows and a right part holding the rest. Also
    splits ``kwargs['previous_hidden_states']``, ``intermediate_inputs`` and
    ``attn_metadata`` the same way. Both halves run on the TBO worker threads
    with ``tbo_in_capture`` set, and the merged output is returned.

    The module-level ``tbo_step_stream`` points at ``stream`` for the duration
    of the call and is restored afterwards, even if the call raises.
    """
    global tbo_step_stream
    if tbo_step_stream is None:
        # Create the step stream before swapping it out, so the restore below
        # does not put back None (eager TBO steps call
        # tbo_step_stream.wait_event).
        tbo_step_stream = torch.cuda.Stream()
    stream_back = tbo_step_stream
    tbo_step_stream = stream
    try:
        init_two_batch_overlap()
        tbo_obj.left_first = True
        num_rows = input_ids.shape[0]
        batch_size_left = num_rows // 2
        batch_size_right = num_rows - batch_size_left
        query_tokens_split = [batch_size_left, batch_size_right]
        input_tokens_left, input_tokens_right = torch.split(
            input_ids, query_tokens_split, dim=0)
        split_input_positions = torch.split(positions, query_tokens_split,
                                            dim=0)

        model_kwargs_left = kwargs.copy()
        model_kwargs_right = kwargs.copy()
        if "previous_hidden_states" in kwargs:
            (model_kwargs_left["previous_hidden_states"],
             model_kwargs_right["previous_hidden_states"]) = torch.split(
                 kwargs["previous_hidden_states"], query_tokens_split, dim=0)

        intermediate_tensors_left = None
        intermediate_tensors_right = None
        if intermediate_inputs is not None:
            left_tensors = {}
            right_tensors = {}
            for key, tensor in intermediate_inputs.tensors.items():
                left_tensors[key], right_tensors[key] = torch.split(
                    tensor, query_tokens_split, dim=0)
            intermediate_tensors_left = IntermediateTensors(left_tensors)
            intermediate_tensors_right = IntermediateTensors(right_tensors)

        attn_metadata_left, attn_metadata_right = \
            split_capture_attention_metadata(attn_metadata, batch_size_left,
                                             batch_size_right)
        tbo_obj.tbo_running = True
        tbo_obj.tbo_in_capture = True
        tbo_obj.set_capture_model_input(input_tokens_left,
                                        input_tokens_right,
                                        split_input_positions,
                                        vllm_config,
                                        virtual_engine,
                                        runner.model,
                                        self_device,
                                        intermediate_tensors_left,
                                        intermediate_tensors_right,
                                        model_kwargs_left,
                                        model_kwargs_right,
                                        attn_metadata_left,
                                        attn_metadata_right)

        states_left, states_right = tbo_obj.get_model_output()
        output_hidden_or_intermediate_states = merge_model_output(
            states_left, states_right)
        tbo_obj.tbo_in_capture = False
        tbo_obj.tbo_running = False
        tbo_obj.finish_thread()
    finally:
        # Never leave the module pointing at the capture stream.
        tbo_step_stream = stream_back
    return output_hidden_or_intermediate_states

def tbo_capture(vllm_config, virtual_engine, _NUM_WARMUP_ITERS, 
        runner,
        self_device,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs):
    """Capture a two-batch-overlap forward of ``runner`` into a CUDA graph.

    The function does three things:

    - Runs ``_run_once`` eagerly ``_NUM_WARMUP_ITERS`` times, synchronizing
      after each run.
    - Runs it once more inside ``torch.cuda.graph`` to record
      ``runner._graph``, allocating from ``memory_pool`` on ``stream``.
    - Stores the tensors the graph uses in ``runner.input_buffers`` and
      ``runner.output_buffers``.
    """
    # Warm-up runs outside of capture, on the current stream.
    for i in range(_NUM_WARMUP_ITERS):
        _run_once(vllm_config, 
                    virtual_engine,
                    runner,
                    self_device,
                    input_ids,
                    positions,
                    intermediate_inputs,
                    attn_metadata,
                    torch.cuda.current_stream(),
                    **kwargs)
        torch.cuda.synchronize()
    runner._graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(runner._graph, pool=memory_pool, stream=stream):
        # torch.cuda.graph makes `stream` the current stream in this context,
        # so _run_once swaps tbo_step_stream to the capture stream.
        output_hidden_or_intermediate_states = _run_once(vllm_config, 
                                                            virtual_engine, 
                                                            runner,
                                                            self_device,
                                                            input_ids,
                                                            positions,
                                                            intermediate_inputs,
                                                            attn_metadata,
                                                            torch.cuda.current_stream(),
                                                            **kwargs)
        # Keep only weak_ref_tensor views of the outputs so the strong
        # reference can be dropped below (weak_ref_tensor presumably returns a
        # non-owning view — confirm in vllm.utils).
        # NOTE(review): if the output is neither a Tensor nor
        # IntermediateTensors, hidden_or_intermediate_states is never bound
        # and the buffer bookkeeping below fails.
        if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
            hidden_or_intermediate_states = weak_ref_tensor(
                output_hidden_or_intermediate_states)
        elif isinstance(output_hidden_or_intermediate_states,
                        IntermediateTensors):
            hidden_or_intermediate_states = IntermediateTensors(
                tensors={
                    key: weak_ref_tensor(value)
                    for key, value in
                    output_hidden_or_intermediate_states.tensors.items()
                })

        del output_hidden_or_intermediate_states
        # make sure `output_hidden_or_intermediate_states` is deleted
        # in the graph's memory pool
        gc.collect()
    torch.cuda.synchronize()
     
    # Save the input and output buffers (presumably the tensors later graph
    # replays read from and write to — confirm against the runner's forward).
    runner.input_buffers = {
        "input_ids":
        input_ids,
        "positions":
        positions,
        "kv_caches":
        kv_caches,
        **runner.attn_state.get_graph_input_buffers(
            attn_metadata, runner._is_encoder_decoder_model),
        **kwargs,
    }
    if intermediate_inputs is not None:
        runner.input_buffers.update(intermediate_inputs.tensors)
    # The last pipeline-parallel rank exposes its output under "hidden_states";
    # other ranks store the captured output object itself.
    if get_pp_group().is_last_rank:
        runner.output_buffers = {
            "hidden_states": hidden_or_intermediate_states
        }
    else:
        runner.output_buffers = hidden_or_intermediate_states