# two_batch_overlap.py

import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
7
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
8
9
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
10
from vllm.sequence import IntermediateTensors
11
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
12
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
13
14
from vllm.logger import init_logger
from vllm.profiler.prof import profile
15
from vllm import envs
16
17
18
19
20
21
22

# Feature switches, read once from the environment at import time. Each one
# is on only when the variable is exactly the string '1'.
# VLLM_TBO_DECODE: allow non-prompt (decode) steps to take the overlapped path.
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
# VLLM_TBO_ONE_STREAM: run all-reduces without the dedicated all-reduce stream.
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

# CUDA streams shared by all TwoBatchOverlap instances. They stay None until
# the first TwoBatchOverlap is constructed, which creates both.
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
    """Run two halves of a batch on two alternating worker threads.

    A "left" and a "right" worker thread each execute the model on one half
    of the batch. The two workers hand control back and forth through a pair
    of semaphores (see ``tbo_thread_synchronize``), and the thread that calls
    ``all_reduce`` services the all-reduce requests the workers enqueue.
    """

    def __init__(self):
        """Create the queues, semaphores and CUDA events used for hand-off.

        The module-level ``tbo_step_stream`` and ``all_reduce_stream`` are
        created on the first construction and reused afterwards.
        """
        global tbo_step_stream
        global all_reduce_stream

        # Inputs travel driver -> worker; final states travel worker -> driver.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # All-reduce requests (worker -> driver) and their results.
        self.all_reduce_queue = queue.Queue()
        self.all_reduce_out = queue.Queue()

        self.left_thread = None
        self.right_thread = None
        # Thread idents; each worker fills in its own when it starts running.
        self.left_tid = 0
        self.right_tid = 0
        # Both start at zero, so a worker blocks until its peer releases it.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # True only for the left worker's first synchronize call of a step.
        self.left_first = False
        self.tbo_running = False

        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()

        self.step_event = torch.cuda.Event(enable_timing=False)
        # Per-worker event pairs used around an all-reduce: the *_c2t event is
        # recorded before the collective and waited on by all_reduce_stream;
        # the *_t2c event is recorded on all_reduce_stream after it.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain(q):
        """Discard every item currently sitting in queue ``q``."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start one worker thread per half-batch.

        Leftover inputs are discarded first so a new worker can never pick up
        an item from an earlier step. (``Queue.empty()`` only reports whether
        a queue is empty; it does not clear it.)
        """
        self._drain(self.model_input_left_queue)
        self._drain(self.model_input_right_queue)

        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.right_thread.start()
        logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and forget them."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker body: run the model once on one half of the batch.

        Args:
            queue: The input queue this worker reads from. Receiving
                ``self.model_input_left_queue`` makes this the left worker;
                any other queue makes it the right worker. (The name shadows
                the ``queue`` module inside this method.)

        The resulting states are put on ``states_left_queue`` or
        ``states_right_queue``.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)

        with torch.cuda.stream(tbo_step_stream):
            model_input = queue.get()
            profile.ProfRangePush('start')
            # Wait for this worker's turn; the returned events are not needed.
            self.tbo_thread_synchronize(tid)

            if is_left_thread:
                model_kwargs = self.model_kwargs_left
                intermediate_tensors = self.intermediate_tensors_left
            else:
                model_kwargs = self.model_kwargs_right
                intermediate_tensors = self.intermediate_tensors_right

            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config, self.virtual_engine):
                hidden_or_intermediate_states = self.model_executable(
                    input_ids=model_input.input_tokens,
                    positions=model_input.input_positions,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                 device=self.self_device),
                    **self.seqlen_agnostic_kwargs,
                    **model_kwargs,
                )

            if is_left_thread:
                # Hand the turn to the right worker so it can finish too.
                self.sem_right.release()
                self.states_left_queue.put(hidden_or_intermediate_states)
            else:
                # None is the sentinel that ends the driver's all_reduce loop.
                self.all_reduce_queue.put(None)
                self.states_right_queue.put(hidden_or_intermediate_states)
            profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Yield to the peer worker and block until it yields back.

        Args:
            tid: Ident of the calling thread. It is treated as the left
                worker when it equals ``self.left_tid``, otherwise as the
                right worker.

        Returns:
            The ``(c2t, t2c)`` event pair belonging to the calling worker.

        The caller's open profiler range is closed before blocking, and a new
        ``'left'`` / ``'right'`` range is opened once the worker resumes.
        """
        if tid == self.left_tid:
            # The left worker's first call of a step only waits: the right
            # worker has not been handed a turn yet, so nothing to give back.
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            return self.event_left_c2t, self.event_left_t2c

        self.sem_left.release()
        profile.ProfRangePop()
        self.sem_right.acquire()
        profile.ProfRangePush('right')
        return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors_left,
                        intermediate_tensors_right,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs_left,
                        model_kwargs_right):
        """Store the per-step arguments and hand each worker its input.

        Everything except the two model inputs is stored on ``self`` for the
        workers to read. The model inputs are enqueued last, because putting
        them is what unblocks the workers.
        """
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right

        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers have published, then return their states.

        Returns:
            A ``(states_left, states_right)`` tuple.
        """
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Service queued all-reduce requests until the ``None`` sentinel.

        Each request is ``[buf, event_c2t, event_t2c]``. With
        ``tbo_one_stream`` set, the collective runs on the caller's current
        stream and the events are ignored. Otherwise it runs on
        ``all_reduce_stream``, ordered after ``event_c2t`` and followed by
        ``event_t2c``. Every result is put on ``all_reduce_out``.
        """
        while True:
            request = self.all_reduce_queue.get()
            if request is None:
                break
            buf, event_c2t, event_t2c = request
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                # Recorded on the caller's current stream, outside the
                # all_reduce_stream context.
                event_c2t.record()
                with torch.cuda.stream(all_reduce_stream):
                    all_reduce_stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    event_t2c.record()
            self.all_reduce_out.put(output)

# Process-wide TwoBatchOverlap instance, created on first use.
tbo_obj = None


def init_two_batch_overlap():
    """Create the shared TwoBatchOverlap on first call and start its workers.

    The instance is built once and kept in the module-level ``tbo_obj``.
    A fresh pair of worker threads is started on every call.
    """
    global tbo_obj
    if tbo_obj is None:
        tbo_obj = TwoBatchOverlap()
    tbo_obj.init_tbo_thread()

def tbo_all_reduce(obj):
    """All-reduce ``obj``, routing through the TBO driver while a step runs.

    When TBO is enabled and a two-batch step is in flight, the request is
    handed to the thread running ``TwoBatchOverlap.all_reduce``. The calling
    worker then yields to its peer and resumes once the peer yields back.
    In every other case the tensor-parallel all-reduce is called directly.

    Args:
        obj: The buffer to all-reduce.

    Returns:
        The all-reduced result.
    """
    tbo_active = (envs.VLLM_ENABLE_TBO and tbo_obj is not None
                  and tbo_obj.tbo_running)
    if not tbo_active:
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    # In one-stream mode the driver ignores the events, but the request must
    # still carry three items. These used to be left unassigned in that mode,
    # which raised UnboundLocalError when the request was built.
    event_c2t = event_t2c = None
    if not tbo_one_stream:
        if tid == tbo_obj.left_tid:
            event_c2t = tbo_obj.event_left_c2t
            event_t2c = tbo_obj.event_left_t2c
        else:
            event_c2t = tbo_obj.event_right_c2t
            event_t2c = tbo_obj.event_right_t2c

    tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
    output = tbo_obj.all_reduce_out.get()
    # Give the peer worker its turn; returns once the peer hands control back.
    tbo_obj.tbo_thread_synchronize(tid)
    if not tbo_one_stream:
        # Later work on the step stream must wait for the collective to end.
        tbo_step_stream.wait_event(event_t2c)
    return output

def merge_model_output(states_left, states_right):
    """Concatenate the two half-batch outputs along dim 0, left then right.

    Args:
        states_left: Output of the left half, either an
            ``IntermediateTensors`` or something ``torch.concat`` accepts.
        states_right: Output of the right half, of the same kind.

    Returns:
        An ``IntermediateTensors`` whose entries are concatenated key by key
        when ``states_left`` is one. The keys are taken from ``states_left``.
        Otherwise the plain concatenation of the two inputs.
    """
    if isinstance(states_left, IntermediateTensors):
        merged = {
            key: torch.concat(
                [states_left.tensors[key], states_right.tensors[key]], dim=0)
            for key in states_left.tensors
        }
        return IntermediateTensors(merged)
    return torch.concat([states_left, states_right], dim=0)

def tbo_model_executable(
        model_input,
        vllm_config,
        virtual_engine,
        model_executable,
        intermediate_tensors,
        multi_modal_kwargs,
        self_device,
        seqlen_agnostic_kwargs,
        model_kwargs,
    ):
    """Run the model, splitting the batch in two overlapped halves if possible.

    The call falls back to one ordinary forward pass when any of these hold:
    the batch has fewer than two sequences, the step is a decode step and
    ``VLLM_TBO_DECODE`` is not set, the attention metadata type is not
    supported, or the step is a CUDA-graph decode. Otherwise the input is
    split into a left and a right half, each half runs on its own worker
    thread, and the two outputs are concatenated.

    Args:
        model_input: Batch input. This function reads ``attn_metadata``,
            ``is_prompt``, ``input_tokens`` and ``input_positions``, and
            ``query_lens`` when per-token tensors have to be split.
        vllm_config: Passed to ``set_forward_context``.
        virtual_engine: Passed to ``set_forward_context``.
        model_executable: Callable that runs the model.
        intermediate_tensors: ``IntermediateTensors`` or ``None``.
        multi_modal_kwargs: Passed through ``MultiModalKwargs.as_kwargs``.
        self_device: Device used for the multimodal kwargs and the split.
        seqlen_agnostic_kwargs: Extra keyword arguments for the model.
        model_kwargs: Extra keyword arguments for the model. A
            ``"previous_hidden_states"`` entry is split per half.

    Returns:
        The hidden or intermediate states for the whole batch.
    """
    attn_metadata = model_input.attn_metadata
    is_support = is_supported_attention_metadata(attn_metadata)
    if not is_support:
        # The argument needs a %s placeholder; without one, logging reports
        # a formatting error instead of emitting the message.
        logger.info("tbo:not surpport yet %s", type(attn_metadata))

    is_cuda_graph_decode = (attn_metadata.use_cuda_graph
                            and not model_input.is_prompt)
    batch_size = len(attn_metadata.seq_lens)
    # An empty batch also takes the plain path; it cannot be split in two.
    use_plain_path = (
        batch_size < 2
        or (not model_input.is_prompt and not enable_tbo_decode)
        or not is_support
        or is_cuda_graph_decode
    )
    if use_plain_path:
        with set_forward_context(attn_metadata, vllm_config, virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states

    profile.ProfRangePush('tbo_model_executable')
    init_two_batch_overlap()
    # NOTE(review): if anything below raises, tbo_running stays True and the
    # worker threads are not joined - confirm whether callers rely on that.
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True

    # The right half takes the extra sequence when the batch size is odd.
    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left

    model_input_left, model_input_right = split_model_input(
        model_input, self_device, batch_size_left, batch_size_right)

    model_kwargs_left = model_kwargs.copy()
    model_kwargs_right = model_kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None

    has_previous_hidden = "previous_hidden_states" in model_kwargs
    if has_previous_hidden or intermediate_tensors is not None:
        # Token counts per half, used to split tensors along dim 0.
        # query_lens is read only when one of the splits below needs it.
        query_lens = model_input.query_lens
        query_tokens_split = [sum(query_lens[0:batch_size_left]),
                              sum(query_lens[batch_size_left:])]

    if has_previous_hidden:
        split_previous_hidden_states = torch.split(
            model_kwargs["previous_hidden_states"], query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = \
            split_previous_hidden_states[0]
        model_kwargs_right["previous_hidden_states"] = \
            split_previous_hidden_states[1]

    if intermediate_tensors is not None:
        tensors_left = {}
        tensors_right = {}
        for key in intermediate_tensors.tensors:
            split_intermediate_tensors = torch.split(
                intermediate_tensors.tensors[key], query_tokens_split, dim=0)
            tensors_left[key] = split_intermediate_tensors[0]
            tensors_right[key] = split_intermediate_tensors[1]
        intermediate_tensors_left = IntermediateTensors(tensors_left)
        intermediate_tensors_right = IntermediateTensors(tensors_right)

    # Make the step stream wait for work already queued on the caller's
    # stream before the workers start issuing kernels on it.
    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        tbo_obj.set_model_input(model_input_left,
                                model_input_right,
                                vllm_config,
                                virtual_engine,
                                model_executable,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                multi_modal_kwargs,
                                self_device,
                                seqlen_agnostic_kwargs,
                                model_kwargs_left,
                                model_kwargs_right)
        # Serve the workers' all-reduce requests until the right worker
        # sends the end-of-step sentinel.
        tbo_obj.all_reduce()
        states_left, states_right = tbo_obj.get_model_output()

        hidden_or_intermediate_states = merge_model_output(
            states_left, states_right)
        tbo_obj.tbo_running = False
        tbo_obj.step_event.record()
        tbo_obj.finish_thread()

    # The caller's stream must not run ahead of the work on the step stream.
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states