two_batch_overlap.py 21.3 KB
Newer Older
1
import gc
2
3
4
import os
import queue
import threading
5
from typing import List, Optional, Tuple
6
import torch
7
from vllm.attention.backends.abstract import AttentionMetadata
8
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
9
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
10
11
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
12
from vllm.sequence import IntermediateTensors
13
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
14
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_capture_attention_metadata, split_model_input
15
16
from vllm.logger import init_logger
from vllm.profiler.prof import profile
17
from vllm import envs
18
from vllm.utils import weak_ref_tensor
zhuwenwen's avatar
zhuwenwen committed
19
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import is_enable_tbo_v1, tbo_all_reduce_v1
20
21
22
23
24

# VLLM_TBO_ONE_STREAM=1 makes tbo_all_reduce issue the all-reduce on the
# current stream instead of the dedicated all_reduce_stream.
tbo_one_stream = os.getenv('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

# Module-wide CUDA streams, created by the first TwoBatchOverlap() instance.
#   tbo_step_stream   - stream both worker threads run the model on
#                       (_run_once temporarily swaps in the capture stream).
#   all_reduce_stream - side stream used by tbo_all_reduce.
tbo_step_stream = all_reduce_stream = None
class TwoBatchOverlap():
    """Runs one model step as two half-batches on two cooperating threads.

    A "left" and a "right" worker thread each run the model on their half of
    the batch. The workers hand control back and forth through two
    semaphores (see ``tbo_thread_synchronize``); ``tbo_all_reduce`` uses this
    hand-off so one half's all-reduce can be in flight while the other half
    computes. Inputs reach the workers through ``model_input_*_queue`` and
    outputs come back through ``states_*_queue``.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream

        # Per-worker input handoff (None is posted in CUDA-graph capture mode).
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-worker model outputs, consumed by get_model_output().
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread idents of the current workers; 0 means "not yet known".
        self.left_tid = 0
        self.right_tid = 0
        # sem_left gates the left worker, sem_right gates the right worker.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # While True, the left worker's next synchronize call does not
        # release sem_right (set by the step drivers before each step).
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False

        # The streams are module-level and shared; create them only once.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        # *_c2t is recorded on the compute stream and waited on by the
        # all-reduce stream; *_t2c is the reverse (see tbo_all_reduce).
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(q):
        """Discard every item currently in ``q`` without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start a fresh left/right worker thread pair for one step."""
        # Queue.empty() only *reports* emptiness; actually drop any input left
        # over from an earlier step so a new worker cannot pick up stale data.
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        # Thread idents can be recycled once the previous step's threads exit;
        # clear them so a new right worker cannot match a stale left_tid
        # before the new left worker has registered itself.
        self.left_tid = 0
        self.right_tid = 0
        self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
        self.right_thread.start()
        logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and forget them."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker body: run the model on one half of the batch.

        ``queue`` is this worker's input queue; passing the left input queue
        makes this the left worker. The result is put on the matching
        ``states_*_queue``.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            model_input = queue.get()
            profile.ProfRangePush('start')
            self.tbo_thread_synchronize(tid)
            if is_left_thread:
                model_kwargs = self.model_kwargs_left
                intermediate_tensors = self.intermediate_tensors_left
            else:
                model_kwargs = self.model_kwargs_right
                intermediate_tensors = self.intermediate_tensors_right

            hidden_or_intermediate_states = None
            if self.tbo_in_capture:
                if is_left_thread:
                    attn_metadata = self.attn_metadata_left
                    input_tokens = self.input_tokens_left
                    input_positions = self.split_input_positions[0]
                else:
                    attn_metadata = self.attn_metadata_right
                    input_tokens = self.input_tokens_right
                    input_positions = self.split_input_positions[1]
                # Capture inputs carry no multimodal data and
                # set_capture_model_input() never sets self.multi_modal_kwargs,
                # so expanding it here either raised AttributeError (no eager
                # TBO step had run yet) or passed a stale eager step's inputs.
                with set_forward_context(attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=input_tokens,
                        positions=input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **model_kwargs,
                    )
            elif model_input is not None:
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **model_kwargs,
                    )

            if is_left_thread:
                # The left half is done: let the right worker run to completion.
                self.sem_right.release()
                self.states_left_queue.put(hidden_or_intermediate_states)
            else:
                self.states_right_queue.put(hidden_or_intermediate_states)
            profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand control to the other worker and block until it hands back.

        ``tid`` is the calling thread's ident. Returns the calling worker's
        ``(c2t, t2c)`` event pair.
        """
        if tid == self.left_tid:
            # The left worker's first call only parks; later calls first wake
            # the right worker.
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors_left,
                        intermediate_tensors_right,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs_left,
                        model_kwargs_right):
        """Store the eager-step inputs and release both workers.

        Shared state is assigned before the per-worker inputs are queued,
        because queuing is what lets each worker proceed.
        """
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def set_capture_model_input(self,
                                input_tokens_left,
                                input_tokens_right,
                                split_input_positions,
                                vllm_config,
                                virtual_engine,
                                runner_model,
                                runner_device,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                model_kwargs_left,
                                model_kwargs_right,
                                attn_metadata_left,
                                attn_metadata_right):
        """Store the CUDA-graph capture inputs and release both workers.

        ``split_input_positions`` is indexed [0] for the left half and [1]
        for the right. ``None`` is queued as the model input; the workers
        read everything from the attributes set here.
        """
        self.input_tokens_left = input_tokens_left
        self.input_tokens_right = input_tokens_right
        self.split_input_positions = split_input_positions
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = runner_model
        self.self_device = runner_device
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both workers have produced output; return (left, right)."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right
# Lazily-created process-wide TwoBatchOverlap; see init_two_batch_overlap().
tbo_obj: Optional["TwoBatchOverlap"] = None
def init_two_batch_overlap():
    """Create the module-wide TwoBatchOverlap on first use and start its workers.

    Both step drivers (tbo_model_executable and _run_once) call this once per
    step, so every call spawns a fresh left/right worker pair.
    """
    global tbo_obj
    if tbo_obj is None:
        tbo_obj = TwoBatchOverlap()
    tbo_obj.init_tbo_thread()

def tbo_all_reduce(obj):
    """Tensor-parallel all-reduce that cooperates with two-batch overlap.

    If TBO v1 is enabled, delegates to ``tbo_all_reduce_v1``. Outside a TBO
    step this is a plain ``tensor_model_parallel_all_reduce``. During a step
    (``tbo_obj.tbo_running``) it is expected to be called from one of the two
    worker threads. It launches the all-reduce, then hands control to the
    other worker via ``tbo_thread_synchronize``. The result is returned once
    this worker is scheduled again.

    Without VLLM_TBO_ONE_STREAM, the all-reduce runs on ``all_reduce_stream``,
    ordered after this worker's prior compute by ``event_c2t``.
    ``tbo_step_stream`` then waits on ``event_t2c`` before continuing. With
    VLLM_TBO_ONE_STREAM=1 the all-reduce is issued on the current stream.
    """
    if is_enable_tbo_v1():
        return tbo_all_reduce_v1(obj)
    if not (envs.VLLM_ENABLE_TBO and tbo_obj is not None and tbo_obj.tbo_running):
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    if tbo_one_stream:
        output = tensor_model_parallel_all_reduce(obj)
        tbo_obj.tbo_thread_synchronize(tid)
        return output

    if tid == tbo_obj.left_tid:
        event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
    else:
        event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
    event_c2t.record()
    with torch.cuda.stream(all_reduce_stream):
        all_reduce_stream.wait_event(event_c2t)
        output = tensor_model_parallel_all_reduce(obj)
        event_t2c.record()
    tbo_obj.tbo_thread_synchronize(tid)
    tbo_step_stream.wait_event(event_t2c)
    return output

def merge_model_output(states_left, states_right):
    """Concatenate the left- and right-half outputs along dim 0.

    Plain tensors are concatenated directly. IntermediateTensors are merged
    key by key, following the left half's keys.
    """
    if not isinstance(states_left, IntermediateTensors):
        return torch.concat([states_left, states_right], dim=0)
    return IntermediateTensors({
        name: torch.concat([left_tensor, states_right.tensors[name]], dim=0)
        for name, left_tensor in states_left.tensors.items()
    })

def tbo_model_executable(
        model_input,
        vllm_config,
        virtual_engine,
        model_executable,
        intermediate_tensors,
        multi_modal_kwargs,
        self_device,
        seqlen_agnostic_kwargs,
        model_kwargs,
    ):
    """Run one eager forward step, overlapping two half-batches when possible.

    When the attention metadata is supported and the batch qualifies, the
    batch is split into a left half of ``batch_size // 2`` sequences and a
    right half holding the rest. The halves run on the TBO worker threads on
    ``tbo_step_stream``, and their outputs are concatenated along dim 0.
    Otherwise the model runs once on the whole batch on the current stream.

    Returns whatever ``model_executable`` returns (hidden states or
    IntermediateTensors), merged across halves in the TBO case.
    """
    attn_metadata = model_input.attn_metadata
    use_tbo = is_supported_attention_metadata(attn_metadata)
    if not use_tbo:
        # Lazy %-formatting: the previous call passed the type as an extra
        # positional arg with no placeholder, so logging reported an error.
        logger.info("tbo:not support yet %s", type(attn_metadata))
    else:
        # Only read seq_lens / use_cuda_graph once the metadata type is known
        # to be supported; unsupported metadata may not carry them.
        batch_size = len(attn_metadata.seq_lens)
        is_decode_tbo_invalid = not model_input.is_prompt and (
            envs.VLLM_TBO_DECODE_BS < 2 or
            batch_size < envs.VLLM_TBO_DECODE_BS or
            attn_metadata.use_cuda_graph)
        # The TBO workers never pass inputs_embeds to the model, so a step
        # driven by embeddings must run unsplit to stay correct.
        use_tbo = (batch_size >= 2
                   and not is_decode_tbo_invalid
                   and model_input.inputs_embeds is None)

    if not use_tbo:
        with set_forward_context(attn_metadata, vllm_config, virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                inputs_embeds=model_input.inputs_embeds,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states

    profile.ProfRangePush('tbo_model_executable')
    init_two_batch_overlap()
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True
    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left

    model_input_left, model_input_right = split_model_input(
        model_input, self_device, batch_size_left, batch_size_right)

    model_kwargs_left = model_kwargs.copy()
    model_kwargs_right = model_kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None

    has_previous_hidden_states = "previous_hidden_states" in model_kwargs
    if has_previous_hidden_states or intermediate_tensors is not None:
        # Per-token tensors are split by the number of query tokens per half.
        query_tokens_split = [sum(model_input.query_lens[:batch_size_left]),
                              sum(model_input.query_lens[batch_size_left:])]
    if has_previous_hidden_states:
        left_part, right_part = torch.split(
            model_kwargs["previous_hidden_states"], query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = left_part
        model_kwargs_right["previous_hidden_states"] = right_part
    if intermediate_tensors is not None:
        tensors_left = {}
        tensors_right = {}
        for key, tensor in intermediate_tensors.tensors.items():
            tensors_left[key], tensors_right[key] = torch.split(
                tensor, query_tokens_split, dim=0)
        intermediate_tensors_left = IntermediateTensors(tensors_left)
        intermediate_tensors_right = IntermediateTensors(tensors_right)

    # Order the worker stream after everything already queued on the caller's
    # stream, and make the caller's stream wait for the merged result below.
    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        tbo_obj.set_model_input(model_input_left,
                                model_input_right,
                                vllm_config,
                                virtual_engine,
                                model_executable,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                multi_modal_kwargs,
                                self_device,
                                seqlen_agnostic_kwargs,
                                model_kwargs_left,
                                model_kwargs_right)
        states_left, states_right = tbo_obj.get_model_output()
        hidden_or_intermediate_states = merge_model_output(states_left, states_right)
        tbo_obj.tbo_running = False
        tbo_obj.step_event.record()
        tbo_obj.finish_thread()
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states

def _run_once(vllm_config, virtual_engine,
        runner,
        self_device,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_inputs: Optional[IntermediateTensors],
        attn_metadata: AttentionMetadata,
        stream: torch.cuda.Stream,
        **kwargs):
    """Run one two-batch-overlap step for CUDA-graph warmup or capture.

    ``input_ids.shape[0]`` is used as both the batch size and the token count
    (one token per sequence; NOTE(review): confirm this holds for every
    caller). Rows are split into ``n // 2`` and ``n - n // 2``. The halves run
    on the TBO worker threads with ``stream`` as their compute stream, and
    the concatenated output is returned.
    """
    global tbo_step_stream
    stream_back = tbo_step_stream
    # The workers enter torch.cuda.stream(tbo_step_stream) as soon as they
    # start, so the swap must happen before init_two_batch_overlap().
    tbo_step_stream = stream
    init_two_batch_overlap()
    tbo_obj.left_first = True

    decode_batch_size = input_ids.shape[0]
    batch_size_left = decode_batch_size // 2
    batch_size_right = decode_batch_size - batch_size_left
    query_tokens_split = [batch_size_left, batch_size_right]
    input_tokens_left, input_tokens_right = torch.split(input_ids, query_tokens_split, dim=0)
    split_input_positions = torch.split(positions, query_tokens_split, dim=0)

    model_kwargs_left = kwargs.copy()
    model_kwargs_right = kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None
    if "previous_hidden_states" in kwargs:
        left_part, right_part = torch.split(
            kwargs["previous_hidden_states"], query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = left_part
        model_kwargs_right["previous_hidden_states"] = right_part
    if intermediate_inputs is not None:
        tensors_left = {}
        tensors_right = {}
        for key, tensor in intermediate_inputs.tensors.items():
            tensors_left[key], tensors_right[key] = torch.split(
                tensor, query_tokens_split, dim=0)
        intermediate_tensors_left = IntermediateTensors(tensors_left)
        intermediate_tensors_right = IntermediateTensors(tensors_right)

    attn_metadata_left, attn_metadata_right = split_capture_attention_metadata(
        attn_metadata, batch_size_left, batch_size_right)
    tbo_obj.tbo_running = True
    tbo_obj.tbo_in_capture = True
    tbo_obj.set_capture_model_input(input_tokens_left,
                                    input_tokens_right,
                                    split_input_positions,
                                    vllm_config,
                                    virtual_engine,
                                    runner.model,
                                    self_device,
                                    intermediate_tensors_left,
                                    intermediate_tensors_right,
                                    model_kwargs_left,
                                    model_kwargs_right,
                                    attn_metadata_left,
                                    attn_metadata_right)

    states_left, states_right = tbo_obj.get_model_output()
    output_hidden_or_intermediate_states = merge_model_output(states_left, states_right)
    tbo_obj.tbo_in_capture = False
    tbo_obj.tbo_running = False
    tbo_obj.finish_thread()
    tbo_step_stream = stream_back
    return output_hidden_or_intermediate_states

def tbo_capture(vllm_config, virtual_engine, _NUM_WARMUP_ITERS,
        runner,
        self_device,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs):
    """Warm up, then capture a CUDA graph of one two-batch-overlap step.

    Runs ``_run_once`` ``_NUM_WARMUP_ITERS`` times eagerly, then once more
    under ``torch.cuda.graph`` into ``runner._graph``. Afterwards it fills
    ``runner.input_buffers`` and ``runner.output_buffers`` (outputs are held
    as weak-ref tensors).
    """
    def run_step():
        # current_stream() is evaluated per call, so inside the graph context
        # it resolves to the capture stream.
        return _run_once(vllm_config, virtual_engine, runner, self_device,
                         input_ids, positions, intermediate_inputs,
                         attn_metadata, torch.cuda.current_stream(),
                         **kwargs)

    for _ in range(_NUM_WARMUP_ITERS):
        run_step()
        torch.cuda.synchronize()

    runner._graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(runner._graph, pool=memory_pool, stream=stream):
        captured = run_step()
        if isinstance(captured, torch.Tensor):
            hidden_or_intermediate_states = weak_ref_tensor(captured)
        elif isinstance(captured, IntermediateTensors):
            hidden_or_intermediate_states = IntermediateTensors(
                tensors={name: weak_ref_tensor(tensor)
                         for name, tensor in captured.tensors.items()})
        # Drop the strong reference and collect so the output is released
        # inside the graph's memory pool.
        del captured
        gc.collect()
    torch.cuda.synchronize()

    # Save the input and output buffers.
    runner.input_buffers = {
        "input_ids": input_ids,
        "positions": positions,
        "kv_caches": kv_caches,
        **runner.attn_state.get_graph_input_buffers(
            attn_metadata, runner._is_encoder_decoder_model),
        **kwargs,
    }
    if intermediate_inputs is not None:
        runner.input_buffers.update(intermediate_inputs.tensors)
    if get_pp_group().is_last_rank:
        runner.output_buffers = {"hidden_states": hidden_or_intermediate_states}
    else:
        runner.output_buffers = hidden_or_intermediate_states