two_batch_overlap.py 13.3 KB
Newer Older
1
2
3
4
5
6

import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
7
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
8
9
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
10
from vllm.sequence import IntermediateTensors
11
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
12
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
13
14
from vllm.logger import init_logger
from vllm.profiler.prof import profile
15
from vllm import envs
16
17
18
19
20
21
22

# Opt-in switches, read once at import time. A switch is on only when its
# environment variable is exactly the string '1'.
# enable_tbo_decode: allow non-prompt (decode) batches to take the TBO path.
enable_tbo_decode = os.getenv('VLLM_TBO_DECODE') == '1'
# tbo_one_stream: run all-reduce on the current stream instead of the
# dedicated all_reduce_stream.
tbo_one_stream = os.getenv('VLLM_TBO_ONE_STREAM') == '1'

# Module-level logger created through vLLM's init_logger factory.
logger = init_logger(__name__)

23
24
25
# CUDA streams shared by every TwoBatchOverlap instance. Both stay None until
# the first TwoBatchOverlap is constructed, which creates them together.
tbo_step_stream = all_reduce_stream = None

26
27
class TwoBatchOverlap():
    """Coordinates two worker threads that each run one half of a split batch.

    Each worker pulls a model input from its own queue, runs the model inside
    ``tbo_step_stream`` and pushes the result onto its states queue. The two
    workers hand execution back and forth through ``sem_left``/``sem_right``
    (see ``tbo_thread_synchronize``). The thread that calls ``all_reduce``
    services the requests posted on ``all_reduce_queue`` and returns each
    result through ``all_reduce_out``.
    """

    def __init__(self):
        """Create the queues, semaphores and CUDA events used by the workers.

        The first instance also creates the module-level ``tbo_step_stream``
        and ``all_reduce_stream``; later instances reuse them.
        """
        global tbo_step_stream
        global all_reduce_stream
        # Per-worker input queues; a ``None`` item tells the worker to exit.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-worker output queues read by get_model_output().
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # Request/response queues for all-reduce work done by all_reduce().
        self.all_reduce_queue = queue.Queue()
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread identifiers, filled in by each worker when it starts.
        self.left_tid = 0
        self.right_tid = 0
        # Both semaphores start at 0, so a worker blocks until the other
        # side releases it.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        self.left_first = False
        self.tbo_running = False
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        # Event pairs handed to all_reduce(): "c2t" is recorded before the
        # all-reduce and "t2c" after it (multi-stream mode only).
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(pending):
        """Discard every item currently waiting in ``pending`` without blocking."""
        while True:
            try:
                pending.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start the left and right worker threads that are not yet running."""
        # Queue.empty() only reports whether a queue is empty; drain the
        # input queues explicitly so a fresh worker never sees a stale item.
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        if self.left_thread is None:
            self.left_thread = threading.Thread(
                target=self.thread_two_batch_overlap,
                args=(self.model_input_left_queue,))
            self.left_thread.start()
        if self.right_thread is None:
            self.right_thread = threading.Thread(
                target=self.thread_two_batch_overlap,
                args=(self.model_input_right_queue,))
            self.right_thread.start()
        logger.info('tbo:two batch overlap threads start')

    def finish_thread(self):
        """Ask both workers to exit and wait for them to terminate."""
        if self.left_thread is not None:
            # None is the shutdown sentinel checked in the worker loop.
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop: run the model on each input taken from ``queue``.

        Args:
            queue: the input queue this worker serves. The worker is the
                "left" one when this is ``self.model_input_left_queue``,
                otherwise the "right" one. (The parameter shadows the
                ``queue`` module inside this method.)

        The loop ends when ``None`` is received from the queue.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            while True:
                model_input = queue.get()
                if model_input is None:
                    break
                profile.ProfRangePush('start')
                # Blocks until it is this worker's turn; the 'start' range is
                # popped and a 'left'/'right' range pushed inside.
                self.tbo_thread_synchronize(tid)
                # The shared call state was stored by set_model_input()
                # before the input was enqueued.
                if is_left_thread:
                    model_kwargs = self.model_kwargs_left
                    intermediate_tensors = self.intermediate_tensors_left
                else:
                    model_kwargs = self.model_kwargs_right
                    intermediate_tensors = self.intermediate_tensors_right
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **model_kwargs,
                    )
                if is_left_thread:
                    # Let the right worker proceed before publishing.
                    self.sem_right.release()
                    self.states_left_queue.put(hidden_or_intermediate_states)
                else:
                    # None ends the all_reduce() service loop.
                    self.all_reduce_queue.put(None)
                    self.states_right_queue.put(hidden_or_intermediate_states)
                profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand execution to the other worker and block until released.

        Args:
            tid: thread identifier of the caller; compared with
                ``self.left_tid`` to decide which side is calling.

        Returns:
            The ``(c2t, t2c)`` CUDA event pair belonging to the calling side.
        """
        if tid == self.left_tid:
            # When left_first is set the left side skips waking the right
            # worker; the flag is cleared once the left side has been released.
            if not self.left_first:
                self.sem_right.release()
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors_left,
                        intermediate_tensors_right,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs_left,
                        model_kwargs_right):
        """Store the shared call state and enqueue one input per worker.

        The worker threads are started first if they are not running. All
        attributes are assigned before the inputs are enqueued because the
        workers read them after taking an input from their queue.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers have published a result.

        Returns:
            A ``(states_left, states_right)`` tuple.
        """
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Serve all-reduce requests until a ``None`` sentinel is received.

        Each request is a ``[buf, event_c2t, event_t2c]`` triple. In
        one-stream mode the all-reduce runs on the current stream and the
        events are ignored. Otherwise ``event_c2t`` is recorded on the
        current stream, ``all_reduce_stream`` waits on it, and ``event_t2c``
        is recorded after the all-reduce inside ``all_reduce_stream``. The
        result is put on ``all_reduce_out``.
        """
        while True:
            obj = self.all_reduce_queue.get()
            if obj is None:
                break
            buf, event_c2t, event_t2c = obj
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                event_c2t.record()
                with torch.cuda.stream(all_reduce_stream):
                    all_reduce_stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    event_t2c.record()
            self.all_reduce_out.put(output)

# Module-wide TwoBatchOverlap instance; created by init_two_batch_overlap()
# and reset to None by finish_two_batch_overlap().
tbo_obj = None

def init_two_batch_overlap():
    """Create the module-wide ``TwoBatchOverlap`` instance on first use.

    Does nothing when ``envs.VLLM_ENABLE_TBO`` is falsy or when an instance
    already exists.
    """
    global tbo_obj
    if envs.VLLM_ENABLE_TBO and tbo_obj is None:
        tbo_obj = TwoBatchOverlap()

def finish_two_batch_overlap():
    """Stop the worker threads of the module-wide instance and drop it.

    Does nothing when no instance exists.
    """
    global tbo_obj
    if tbo_obj is not None:
        tbo_obj.finish_thread()
        tbo_obj = None

def tbo_all_reduce(obj):
    """All-reduce ``obj``, routing through the TBO machinery while it runs.

    When TBO is enabled and a TBO step is in progress, the request is posted
    on ``tbo_obj.all_reduce_queue``, the caller waits for the result on
    ``tbo_obj.all_reduce_out`` and then yields to the other worker through
    ``tbo_thread_synchronize``. In multi-stream mode ``tbo_step_stream``
    additionally waits for the event recorded after the all-reduce.
    Otherwise the regular tensor-parallel all-reduce is called directly.

    Args:
        obj: the buffer to all-reduce.

    Returns:
        The all-reduced result.
    """
    if envs.VLLM_ENABLE_TBO and tbo_obj is not None and tbo_obj.tbo_running:
        tid = threading.get_ident()
        # In one-stream mode no events are needed; bind them to None so the
        # request below can still be built (they were previously unbound,
        # which raised UnboundLocalError).
        event_c2t = event_t2c = None
        if not tbo_one_stream:
            if tid == tbo_obj.left_tid:
                event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
            else:
                event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
        tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
        output = tbo_obj.all_reduce_out.get()
        tbo_obj.tbo_thread_synchronize(tid)
        if not tbo_one_stream:
            tbo_step_stream.wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)

def merge_model_output(states_left, states_right):
    """Join the outputs of the left and right half-batches along dim 0.

    Args:
        states_left: output of the left half; either an
            ``IntermediateTensors`` or a value accepted by ``torch.concat``.
        states_right: output of the right half, of the same kind.

    Returns:
        An ``IntermediateTensors`` whose entries are the per-key
        concatenations when ``states_left`` is an ``IntermediateTensors``,
        otherwise the concatenation of the two inputs.
    """
    if not isinstance(states_left, IntermediateTensors):
        return torch.concat([states_left, states_right], dim=0)

    # Keys are taken from the left side; the right side is indexed by them.
    merged = {
        name: torch.concat(
            [states_left.tensors[name], states_right.tensors[name]], dim=0)
        for name in states_left.tensors
    }
    return IntermediateTensors(merged)

def tbo_model_executable(
        model_input,
        vllm_config,
        virtual_engine,
        model_executable,
        intermediate_tensors,
        multi_modal_kwargs,
        self_device,
        seqlen_agnostic_kwargs,
        model_kwargs,
    ):
    """Run the model, splitting the batch in two overlapped halves if possible.

    The model is called once on the whole batch (no split) when any of these
    holds: the batch has a single sequence, the step is a decode step and
    ``VLLM_TBO_DECODE`` is off, the attention metadata is not supported, or
    the step is a CUDA-graph decode. Otherwise the input is split into a left
    and a right half, both halves are run by the ``TwoBatchOverlap`` worker
    threads while this thread services their all-reduce requests, and the two
    outputs are concatenated.

    Args:
        model_input: the model input; its ``attn_metadata``, ``is_prompt``,
            ``input_tokens``, ``input_positions`` and ``query_lens`` are read.
        vllm_config: forwarded to ``set_forward_context``.
        virtual_engine: forwarded to ``set_forward_context``.
        model_executable: the callable that runs the model.
        intermediate_tensors: ``IntermediateTensors`` or ``None``; split per
            half along dim 0 by query-token count when present.
        multi_modal_kwargs: passed through ``MultiModalKwargs.as_kwargs``.
        self_device: device used for the multi-modal kwargs and the split.
        seqlen_agnostic_kwargs: extra keyword arguments for the model.
        model_kwargs: extra keyword arguments for the model; a
            ``"previous_hidden_states"`` entry is split per half.

    Returns:
        The hidden or intermediate states for the whole batch.
    """
    attn_metadata = model_input.attn_metadata
    is_support = is_supported_attention_metadata(attn_metadata)
    if not is_support:
        # The placeholder is required: logging an argument without one makes
        # the logging module report a formatting error instead of the message.
        logger.info("tbo:not surpport yet %s", type(attn_metadata))

    is_cuda_graph_decode = attn_metadata.use_cuda_graph and not model_input.is_prompt
    batch_size = len(attn_metadata.seq_lens)
    if batch_size == 1 or \
        (not model_input.is_prompt and not enable_tbo_decode) or \
        not is_support or \
        is_cuda_graph_decode:
        with set_forward_context(attn_metadata, vllm_config, virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states

    profile.ProfRangePush('tbo_model_executable')
    # NOTE(review): init_two_batch_overlap() leaves tbo_obj as None when
    # envs.VLLM_ENABLE_TBO is off, so this path presumably is only reached
    # with TBO enabled - confirm against the caller.
    init_two_batch_overlap()
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True

    # The right half receives the extra sequence when the batch size is odd.
    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left

    model_input_left, model_input_right = split_model_input(
        model_input, self_device, batch_size_left, batch_size_right)

    model_kwargs_left = model_kwargs.copy()
    model_kwargs_right = model_kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None

    has_previous_hidden_states = "previous_hidden_states" in model_kwargs
    if has_previous_hidden_states or intermediate_tensors is not None:
        # Number of query tokens in each half, used as torch.split sizes.
        query_tokens_split = [
            sum(model_input.query_lens[0:batch_size_left]),
            sum(model_input.query_lens[batch_size_left:]),
        ]

    if has_previous_hidden_states:
        split_previous_hidden_states = torch.split(
            model_kwargs["previous_hidden_states"], query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
        model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]

    if intermediate_tensors is not None:
        intermediate_tensors_left = {}
        intermediate_tensors_right = {}
        for key in intermediate_tensors.tensors:
            split_intermediate_tensors = torch.split(
                intermediate_tensors.tensors[key], query_tokens_split, dim=0)
            intermediate_tensors_left[key] = split_intermediate_tensors[0]
            intermediate_tensors_right[key] = split_intermediate_tensors[1]
        intermediate_tensors_left = IntermediateTensors(intermediate_tensors_left)
        intermediate_tensors_right = IntermediateTensors(intermediate_tensors_right)

    # Make tbo_step_stream wait for the work already queued on the caller's
    # stream before the workers start.
    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        tbo_obj.set_model_input(model_input_left,
                                model_input_right,
                                vllm_config,
                                virtual_engine,
                                model_executable,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                multi_modal_kwargs,
                                self_device,
                                seqlen_agnostic_kwargs,
                                model_kwargs_left,
                                model_kwargs_right)
        # Serve the workers' all-reduce requests until the sentinel arrives,
        # then collect both halves.
        tbo_obj.all_reduce()
        states_left, states_right = tbo_obj.get_model_output()

        hidden_or_intermediate_states = merge_model_output(states_left, states_right)
        tbo_obj.tbo_running = False
        tbo_obj.step_event.record()
    # The caller's stream must not continue before the TBO step has finished.
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states