# two_batch_overlap.py

import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
10
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
11
12
from vllm.logger import init_logger
from vllm.profiler.prof import profile
13
from vllm import envs
# Feature switches, read once from the environment at import time.
# VLLM_TBO_DECODE=1 lets decode (non-prompt) steps be split as well; without
# it tbo_model_executable only splits prompt steps.
enable_tbo_decode = os.getenv('VLLM_TBO_DECODE') == '1'
# VLLM_TBO_ONE_STREAM=1 makes TwoBatchOverlap.all_reduce() run the all-reduce
# on the current stream instead of on all_reduce_stream.
tbo_one_stream = os.getenv('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

# CUDA streams shared by every TwoBatchOverlap instance. They stay None until
# the first TwoBatchOverlap() is constructed.
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap:
    """Run one model step as two half-batches on two worker threads.

    A "left" and a "right" worker thread each run the model on their half of
    the batch under ``tbo_step_stream``. At every synchronization point
    (``tbo_thread_synchronize``) a worker releases the other worker's
    semaphore and blocks on its own, so control ping-pongs between them.
    All-reduce requests issued by the workers (via ``tbo_all_reduce``) are
    passed through ``all_reduce_queue`` to the thread that called
    :meth:`all_reduce`, which performs them and sends results back through
    ``all_reduce_out``.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream

        # Per-worker input queues; a ``None`` item tells the worker to exit.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-worker model outputs, consumed by get_model_output().
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # All-reduce requests from the workers, and the results sent back.
        self.all_reduce_queue = queue.Queue()
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # threading.get_ident() of each worker, used to tell callers apart.
        self.left_tid = 0
        self.right_tid = 0
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # When True, the left worker's next sync does not release the right
        # worker, so the left worker takes the first turn of a step.
        self.left_first = False
        # True while tbo_model_executable is running a split step; read by
        # tbo_all_reduce() to decide whether to route through all_reduce().
        self.tbo_running = False

        # The streams are module-level singletons shared by all instances.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()

        self.step_event = torch.cuda.Event(enable_timing=False)
        # Per-worker event pairs used by all_reduce()/tbo_all_reduce() to
        # order work between the step stream and all_reduce_stream.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(q):
        """Discard every item currently buffered in ``q`` without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start the left and right worker threads.

        Any items still buffered in the per-worker input queues are discarded
        first.
        """
        # Queue.empty() only reports emptiness; drain the queues explicitly.
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.left_thread.start()
        self.right_thread.start()
        logger.info('tbo:two batch overlap threads start')

    def finish_thread(self):
        """Send the exit sentinel to each started worker and join it."""
        if self.left_thread is not None:
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop: run the model on each input taken from ``queue``.

        ``queue`` (which shadows the ``queue`` module inside this method) is
        either ``model_input_left_queue`` or ``model_input_right_queue``; that
        decides whether this thread acts as the left or the right worker.
        The loop exits when it receives ``None``.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
        else:
            self.right_tid = tid
        init_tbo_forward_context(is_left_thread, tid)

        with torch.cuda.stream(tbo_step_stream):
            while True:
                model_input = queue.get()
                if model_input is None:
                    break
                profile.ProfRangePush('start')
                # Wait for this worker's turn before touching the GPU.
                self.tbo_thread_synchronize(tid)
                model_kwargs = (self.model_kwargs_left if is_left_thread
                                else self.model_kwargs_right)
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=self.intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **model_kwargs,
                    )
                if is_left_thread:
                    # Hand the turn to the right worker for the rest of its pass.
                    self.sem_right.release()
                    self.states_left_queue.put(hidden_or_intermediate_states)
                else:
                    # ``None`` ends the all_reduce() service loop.
                    self.all_reduce_queue.put(None)
                    self.states_right_queue.put(hidden_or_intermediate_states)
                profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand the turn to the other worker and block until it is handed back.

        ``tid`` identifies the calling worker. Any tid other than
        ``left_tid`` is treated as the right worker.

        Returns:
            The caller's ``(event_c2t, event_t2c)`` event pair.
        """
        if tid == self.left_tid:
            # On the left worker's first sync of a step the right worker is
            # not released, so the left worker goes first.
            if not self.left_first:
                self.sem_right.release()
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        self.sem_left.release()
        profile.ProfRangePop()
        self.sem_right.acquire()
        profile.ProfRangePush('right')
        return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs_left,
                        model_kwargs_right):
        """Store the shared forward arguments and dispatch both half-batches.

        Starts the worker threads on first use. The shared arguments are
        stored on ``self`` before the inputs are queued, so a worker reads
        them only after its input arrives.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors = intermediate_tensors
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers have produced output.

        Returns:
            ``(states_left, states_right)``.
        """
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Serve worker all-reduce requests until a ``None`` request arrives.

        Each request is ``[buf, event_c2t, event_t2c]``. The reduced tensor is
        put on ``all_reduce_out`` for the requesting worker.
        """
        while True:
            obj = self.all_reduce_queue.get()
            if obj is None:
                break
            buf, event_c2t, event_t2c = obj
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                # Record on the current stream and make all_reduce_stream wait
                # for it. Run the all-reduce there, then record event_t2c so the
                # worker's stream can wait for the result.
                event_c2t.record()
                with torch.cuda.stream(all_reduce_stream):
                    all_reduce_stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    event_t2c.record()
            self.all_reduce_out.put(output)

# Process-wide TwoBatchOverlap instance. It is created by
# init_two_batch_overlap() and reset to None by finish_two_batch_overlap().
tbo_obj = None

def init_two_batch_overlap():
    """Create the module-level ``tbo_obj`` singleton when TBO is enabled.

    Does nothing if ``envs.VLLM_ENABLE_TBO`` is falsy or the singleton
    already exists.
    """
    global tbo_obj
    if envs.VLLM_ENABLE_TBO and tbo_obj is None:
        tbo_obj = TwoBatchOverlap()

def finish_two_batch_overlap():
    """Stop the singleton's worker threads (if any) and drop the singleton."""
    global tbo_obj
    if tbo_obj is not None:
        tbo_obj.finish_thread()
        tbo_obj = None

def tbo_all_reduce(obj):
    """All-reduce ``obj``, routing through the TBO service loop when active.

    While a TBO step is running (``tbo_obj.tbo_running``), the request is
    queued for the thread running ``TwoBatchOverlap.all_reduce()``. The
    calling worker then hands its turn to the other worker via
    ``tbo_thread_synchronize`` before the result is used. Otherwise this is
    a plain ``tensor_model_parallel_all_reduce(obj)``.
    """
    if not (envs.VLLM_ENABLE_TBO and tbo_obj is not None
            and tbo_obj.tbo_running):
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    # all_reduce() ignores the events in one-stream mode. They still have to
    # be bound because they travel with the request (previously they were
    # unbound here, raising UnboundLocalError).
    event_c2t = event_t2c = None
    if not tbo_one_stream:
        if tid == tbo_obj.left_tid:
            event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
        else:
            event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c

    tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
    output = tbo_obj.all_reduce_out.get()
    tbo_obj.tbo_thread_synchronize(tid)
    if not tbo_one_stream:
        # Make the worker's stream wait until the all-reduce recorded on
        # all_reduce_stream has completed before ``output`` is consumed.
        tbo_step_stream.wait_event(event_t2c)
    return output

def merge_model_output(states_left, states_right):
    """Join the left and right half-batch outputs back into one batch along dim 0."""
    halves = (states_left, states_right)
    return torch.cat(halves, dim=0)

def tbo_model_executable(
        model_input,
        vllm_config,
        virtual_engine,
        model_executable,
        intermediate_tensors,
        multi_modal_kwargs,
        self_device,
        seqlen_agnostic_kwargs,
        model_kwargs,
    ):
    """Run ``model_executable`` on ``model_input``, split in two when possible.

    Runs a single, unsplit forward pass when any of these hold:
    ``is_supported_attention_metadata`` rejects the attention metadata; the
    batch has fewer than two sequences; the step is a decode step and
    VLLM_TBO_DECODE is not enabled; or it is a CUDA-graph decode step.

    Otherwise the batch is split into a left half of ``batch_size // 2``
    sequences and a right half holding the rest. The halves run on the
    TwoBatchOverlap worker threads, and their outputs are concatenated
    along dim 0.

    Returns:
        The model's hidden (or intermediate) states for the whole batch.
    """
    attn_metadata = model_input.attn_metadata
    is_support = is_supported_attention_metadata(attn_metadata)
    if not is_support:
        # Use a %s placeholder; an extra argument with no placeholder makes
        # logging report a formatting error instead of the message.
        logger.info("tbo:not support yet %s", type(attn_metadata))

    # Only read seq_lens / use_cuda_graph from metadata that is supported, so
    # unsupported metadata always reaches the fallback path.
    run_unsplit = not is_support
    if not run_unsplit:
        batch_size = len(attn_metadata.seq_lens)
        is_cuda_graph_decode = (attn_metadata.use_cuda_graph
                                and not model_input.is_prompt)
        run_unsplit = (batch_size < 2
                       or (not model_input.is_prompt and not enable_tbo_decode)
                       or is_cuda_graph_decode)

    if run_unsplit:
        with set_forward_context(attn_metadata, vllm_config, virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states

    profile.ProfRangePush('tbo_model_executable')
    init_two_batch_overlap()

    tbo_obj.tbo_running = True
    tbo_obj.left_first = True
    try:
        batch_size_left = batch_size // 2
        batch_size_right = batch_size - batch_size_left
        model_input_left, model_input_right = split_model_input(
            model_input, self_device, batch_size_left, batch_size_right)

        # Give each half its own shallow copy. Both halves need their own
        # "previous_hidden_states" slice (sharing one dict let the right
        # slice overwrite the left), and the caller's dict must stay intact.
        model_kwargs_left = dict(model_kwargs)
        model_kwargs_right = dict(model_kwargs)
        if "previous_hidden_states" in model_kwargs:
            prev_left, prev_right = torch.split(
                model_kwargs["previous_hidden_states"],
                [batch_size_left, batch_size_right], dim=0)
            model_kwargs_left["previous_hidden_states"] = prev_left
            model_kwargs_right["previous_hidden_states"] = prev_right

        # Order tbo_step_stream after work already queued on the caller's stream.
        tbo_obj.step_event.record()
        current_stream = torch.cuda.current_stream()
        with torch.cuda.stream(tbo_step_stream):
            tbo_step_stream.wait_event(tbo_obj.step_event)
            tbo_obj.set_model_input(model_input_left,
                                    model_input_right,
                                    vllm_config,
                                    virtual_engine,
                                    model_executable,
                                    intermediate_tensors,
                                    multi_modal_kwargs,
                                    self_device,
                                    seqlen_agnostic_kwargs,
                                    model_kwargs_left,
                                    model_kwargs_right)
            # Serve the workers' all-reduce requests on this thread until the
            # right worker finishes its forward pass.
            tbo_obj.all_reduce()
            states_left, states_right = tbo_obj.get_model_output()
            hidden_or_intermediate_states = merge_model_output(
                states_left, states_right)
            tbo_obj.step_event.record()
        # Order the caller's stream after the TBO step.
        current_stream.wait_event(tbo_obj.step_event)
    finally:
        # Always clear the flag. If it stayed set after a failure, later
        # tbo_all_reduce() calls would queue requests that nothing serves.
        tbo_obj.tbo_running = False
        profile.ProfRangePop()
    return hidden_or_intermediate_states