# two_batch_overlap.py

import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
10
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
11
12
from vllm.logger import init_logger
from vllm.profiler.prof import profile
13
from vllm import envs
14
# When '1', decode (non-prompt) steps are also split into two half-batches;
# otherwise tbo_model_executable only splits prompt (prefill) steps.
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'

# When '1', the main thread issues each all-reduce directly on its current
# stream instead of on the separate all_reduce_stream synchronized by events.
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

# CUDA streams shared by all TwoBatchOverlap instances; created lazily by the
# first TwoBatchOverlap.__init__.
tbo_step_stream = None
all_reduce_stream = None

class TwoBatchOverlap:
    """Run one model step as two half-batches on two cooperating threads.

    Two worker threads ("left" and "right") each run the model on one half
    of the batch on the shared ``tbo_step_stream``. They take turns: the
    semaphores ``sem_left``/``sem_right`` hand control back and forth so
    only one worker issues work at a time, switching at every all-reduce
    (``tbo_all_reduce`` -> ``tbo_thread_synchronize``). The main thread runs
    ``all_reduce()``, servicing the workers' all-reduce requests (on
    ``all_reduce_stream`` unless ``VLLM_TBO_ONE_STREAM=1``), until the right
    worker posts ``None`` to ``all_reduce_queue`` at the end of its step.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream

        # Per-step half-batch inputs for each worker; ``None`` is the
        # shutdown sentinel posted by finish_thread().
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-step model outputs produced by each worker.
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # Worker -> main thread all-reduce requests ([buf, event_c2t,
        # event_t2c]); ``None`` marks the end of the current step.
        self.all_reduce_queue = queue.Queue()
        # Main thread -> worker all-reduce results.
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread identifiers of the workers, used to tell them apart.
        self.left_tid = 0
        self.right_tid = 0
        # Turn-taking semaphores: a worker proceeds past
        # tbo_thread_synchronize only after acquiring its own semaphore.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # Set to True at the start of a step so the left worker does not
        # hand a turn to the right worker before it has run at all.
        self.left_first = False
        # True while tbo_model_executable runs a split step; makes
        # tbo_all_reduce route requests through this object.
        self.tbo_running = False

        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        # Orders the caller's stream against tbo_step_stream at the start
        # and end of a split step.
        self.step_event = torch.cuda.Event(enable_timing=False)
        # Per-worker events: *_c2t is recorded on the main thread's current
        # stream and waited on by all_reduce_stream before an all-reduce;
        # *_t2c is recorded on all_reduce_stream after it and waited on by
        # tbo_step_stream in the worker.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    def init_tbo_thread(self):
        """Clear any stale inputs and start the left and right workers."""
        # Queue.empty() only reports emptiness; drain explicitly so no stale
        # input is picked up by the freshly started workers.
        for input_queue in (self.model_input_left_queue,
                            self.model_input_right_queue):
            while True:
                try:
                    input_queue.get_nowait()
                except queue.Empty:
                    break
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.left_thread.start()
        self.right_thread.start()
        logger.info('tbo:two batch overlap threads start')

    def finish_thread(self):
        """Send the shutdown sentinel to each running worker and join it."""
        if self.left_thread is not None:
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop: run the model on each half-batch taken from ``queue``.

        ``queue`` is ``model_input_left_queue`` or ``model_input_right_queue``;
        which one it is decides whether this thread is the left or right
        worker. The loop exits when it receives ``None``.
        (The parameter shadows the ``queue`` module inside this method.)
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
        else:
            self.right_tid = tid
        init_tbo_forward_context(is_left_thread, tid)

        # NOTE(review): an exception from model_executable ends this thread
        # without releasing the other worker or posting the end-of-step
        # sentinel, which leaves the main thread blocked in all_reduce().
        with torch.cuda.stream(tbo_step_stream):
            while True:
                model_input = queue.get()
                if model_input is None:
                    break
                # Popped inside tbo_thread_synchronize once our turn arrives.
                profile.ProfRangePush('start')
                self.tbo_thread_synchronize(tid)
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=self.intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **self.model_kwargs,
                    )
                if is_left_thread:
                    # Let the right worker run its final segment.
                    self.sem_right.release()
                    self.states_left_queue.put(hidden_or_intermediate_states)
                else:
                    # End-of-step sentinel for the main thread's all_reduce().
                    self.all_reduce_queue.put(None)
                    self.states_right_queue.put(hidden_or_intermediate_states)
                profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand the turn to the other worker and wait for ours to come back.

        Args:
            tid: ``threading.get_ident()`` of the calling worker.

        Returns:
            ``(event_c2t, event_t2c)`` belonging to the calling worker.
        """
        if tid == self.left_tid:
            # At the start of a step the left worker goes first, so it must
            # not wake the right worker before having run.
            if not self.left_first:
                self.sem_right.release()
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs):
        """Store the shared step arguments and hand each worker its half.

        Starts the worker threads on first use. The shared arguments are
        stored before the inputs are queued, so the workers see them once
        they dequeue their input.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors = intermediate_tensors
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs = model_kwargs
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers have produced output; return (left, right)."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Service worker all-reduce requests until the end-of-step sentinel.

        Each request is ``[buf, event_c2t, event_t2c]``; the result is put on
        ``all_reduce_out`` for the requesting worker.
        """
        while True:
            obj = self.all_reduce_queue.get()
            if obj is None:
                break
            buf, event_c2t, event_t2c = obj
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                # Make all_reduce_stream wait for work queued so far on the
                # current stream, run the all-reduce there, and mark its
                # completion so the worker's stream can wait on it.
                event_c2t.record()
                with torch.cuda.stream(all_reduce_stream):
                    all_reduce_stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    event_t2c.record()
            self.all_reduce_out.put(output)

# Process-wide TwoBatchOverlap singleton, created by init_two_batch_overlap()
# and cleared by finish_two_batch_overlap().
tbo_obj = None

def init_two_batch_overlap():
    """Create the TwoBatchOverlap singleton if TBO is enabled.

    Does nothing when ``envs.VLLM_ENABLE_TBO`` is falsy or the singleton
    already exists.
    """
    global tbo_obj
    if envs.VLLM_ENABLE_TBO and tbo_obj is None:
        tbo_obj = TwoBatchOverlap()

def finish_two_batch_overlap():
    """Shut down the TBO singleton's worker threads and drop the singleton.

    A no-op when no singleton exists.
    """
    global tbo_obj
    instance = tbo_obj
    if instance is None:
        return
    instance.finish_thread()
    tbo_obj = None

def tbo_all_reduce(obj):
    """All-reduce ``obj`` over the tensor-parallel group, TBO-aware.

    Outside a running TBO step this is a plain
    ``tensor_model_parallel_all_reduce``. During a step (expected to be
    called from one of the worker threads), the request is forwarded to the
    main thread's ``TwoBatchOverlap.all_reduce`` loop, the turn is handed to
    the other worker, and, in two-stream mode, ``tbo_step_stream`` is made
    to wait for the all-reduce to finish before the result is used.

    Returns:
        The all-reduced tensor.
    """
    if not (envs.VLLM_ENABLE_TBO and tbo_obj is not None
            and tbo_obj.tbo_running):
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    # One-stream mode uses no cross-stream events; bind them anyway so the
    # request tuple is always well-formed (previously a NameError).
    event_c2t = event_t2c = None
    if not tbo_one_stream:
        if tid == tbo_obj.left_tid:
            event_c2t, event_t2c = (tbo_obj.event_left_c2t,
                                    tbo_obj.event_left_t2c)
        else:
            event_c2t, event_t2c = (tbo_obj.event_right_c2t,
                                    tbo_obj.event_right_t2c)

    tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
    output = tbo_obj.all_reduce_out.get()
    # Let the other worker run; returns once it is our turn again.
    tbo_obj.tbo_thread_synchronize(tid)
    if not tbo_one_stream:
        tbo_step_stream.wait_event(event_t2c)
    return output

def merge_model_output(states_left, states_right):
    """Join the left and right half-batch outputs along dim 0, left first."""
    return torch.cat((states_left, states_right), dim=0)

def _run_model_executable(model_input, vllm_config, virtual_engine,
                          model_executable, intermediate_tensors,
                          multi_modal_kwargs, self_device,
                          seqlen_agnostic_kwargs, model_kwargs):
    """Run ``model_executable`` once on the whole, unsplit ``model_input``."""
    with set_forward_context(model_input.attn_metadata,
                             vllm_config, virtual_engine):
        return model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            intermediate_tensors=intermediate_tensors,
            **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                         device=self_device),
            **seqlen_agnostic_kwargs,
            **model_kwargs,
        )


def tbo_model_executable(
        model_input,
        vllm_config,
        virtual_engine,
        model_executable,
        intermediate_tensors,
        multi_modal_kwargs,
        self_device,
        seqlen_agnostic_kwargs,
        model_kwargs,
    ):
    """Run one model step, splitting it into two overlapped half-batches.

    Falls back to a single unsplit run when any of these hold:
    - the attention metadata is unsupported;
    - the batch has fewer than two sequences;
    - it is a decode step and ``VLLM_TBO_DECODE`` is not '1';
    - it is a CUDA-graph decode step;
    - TBO is disabled (no singleton).

    Otherwise the batch is split into ``batch_size // 2`` and the remainder
    (the right half gets the extra sequence when odd). The halves run on the
    TBO worker threads on ``tbo_step_stream``, and their outputs are
    concatenated along dim 0.

    Returns:
        The model's hidden or intermediate states for the full batch.
    """
    run_args = (model_input, vllm_config, virtual_engine, model_executable,
                intermediate_tensors, multi_modal_kwargs, self_device,
                seqlen_agnostic_kwargs, model_kwargs)
    attn_metadata = model_input.attn_metadata

    # Check support first so unsupported metadata never has its
    # use_cuda_graph / seq_lens attributes touched.
    if not is_supported_attention_metadata(attn_metadata):
        logger.info("tbo:not supported yet: %s", type(attn_metadata))
        return _run_model_executable(*run_args)

    batch_size = len(attn_metadata.seq_lens)
    is_decode = not model_input.is_prompt
    is_cuda_graph_decode = attn_metadata.use_cuda_graph and is_decode
    if (batch_size < 2
            or (is_decode and not enable_tbo_decode)
            or is_cuda_graph_decode):
        return _run_model_executable(*run_args)

    init_two_batch_overlap()
    if tbo_obj is None:
        # envs.VLLM_ENABLE_TBO is off, so no TwoBatchOverlap was created.
        return _run_model_executable(*run_args)

    profile.ProfRangePush('tbo_model_executable')
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True
    try:
        batch_size_left = batch_size // 2
        batch_size_right = batch_size - batch_size_left
        model_input_left, model_input_right = split_model_input(
            model_input, self_device, batch_size_left, batch_size_right)

        # tbo_step_stream must not start before work already queued on the
        # caller's stream.
        tbo_obj.step_event.record()
        current_stream = torch.cuda.current_stream()
        with torch.cuda.stream(tbo_step_stream):
            tbo_step_stream.wait_event(tbo_obj.step_event)
            tbo_obj.set_model_input(model_input_left,
                                    model_input_right,
                                    vllm_config,
                                    virtual_engine,
                                    model_executable,
                                    intermediate_tensors,
                                    multi_modal_kwargs,
                                    self_device,
                                    seqlen_agnostic_kwargs,
                                    model_kwargs)
            # Service the workers' all-reduces until the step finishes.
            tbo_obj.all_reduce()
            states_left, states_right = tbo_obj.get_model_output()
            hidden_or_intermediate_states = merge_model_output(states_left,
                                                               states_right)
            tbo_obj.step_event.record()
        # The caller's stream must see the merged output.
        current_stream.wait_event(tbo_obj.step_event)
    finally:
        tbo_obj.tbo_running = False
        profile.ProfRangePop()
    return hidden_or_intermediate_states