# two_batch_overlap.py
1
2
3
4
5
6

import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
7
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
8
9
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
10
from vllm.sequence import IntermediateTensors
11
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
12
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
13
14
from vllm.logger import init_logger
from vllm.profiler.prof import profile
15
from vllm import envs
16
17
18
19
20

# With VLLM_TBO_ONE_STREAM=1, tbo_all_reduce issues the all-reduce on the
# current stream instead of switching to the dedicated all_reduce_stream.
tbo_one_stream = os.getenv('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

# CUDA streams shared by every TwoBatchOverlap instance; both are created
# together by the first TwoBatchOverlap() constructed.
tbo_step_stream = None
all_reduce_stream = None

24
25
class TwoBatchOverlap:
    """Execute one model step as two half-batches on two worker threads.

    The "left" and "right" worker threads each run ``model_executable`` on
    their half of the batch under the shared ``tbo_step_stream``. They take
    turns through ``sem_left`` / ``sem_right``: ``tbo_thread_synchronize``
    (called once when a worker starts and again from ``tbo_all_reduce``)
    releases the partner and blocks until the partner hands the turn back.
    Each worker's result, or the exception it raised, is returned through
    ``states_left_queue`` / ``states_right_queue``.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream

        # Inputs for the current step are handed to the workers through these
        # queues; each worker puts its result (or exception) on its states queue.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        self.left_tid = 0
        self.right_tid = 0
        # Turn-taking semaphores: a worker blocks on its own semaphore until
        # the partner releases it.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # Set by the caller before a step; cleared by the left worker's first
        # tbo_thread_synchronize call.
        self.left_first = False
        # Read by tbo_all_reduce to decide whether to take part in the ping-pong.
        self.tbo_running = False
        # Set when a worker raised: synchronization becomes a no-op so the
        # surviving worker can finish instead of waiting forever.
        self.tbo_aborted = False

        # The module-level streams are created once, by the first instance.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        # Per-worker event pairs used by tbo_all_reduce: *_c2t is recorded on
        # the compute stream and waited on by all_reduce_stream; *_t2c is
        # recorded on all_reduce_stream and waited on by tbo_step_stream.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(q):
        """Discard every item currently in ``q`` without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Reset per-step state and start the left and right worker threads."""
        # Queue.empty() only *tests* for emptiness; drain explicitly so that a
        # stale item left by an aborted step cannot reach the new workers.
        for q in (self.model_input_left_queue, self.model_input_right_queue,
                  self.states_left_queue, self.states_right_queue):
            self._drain_queue(q)
        # An aborted step may leave extra permits behind; start from zero.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        self.tbo_aborted = False

        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.left_thread.start()
        # start() returns only once the thread is running, so .ident is set.
        self.left_tid = self.left_thread.ident
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.right_thread.start()
        self.right_tid = self.right_thread.ident
        logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and drop the references to them."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    def _abort(self):
        """Stop the turn-taking and wake any worker blocked on a semaphore."""
        self.tbo_aborted = True
        self.sem_left.release()
        self.sem_right.release()

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker-thread entry point: run one half-batch and publish the result.

        ``queue`` is the input queue; it identifies whether this thread is
        the left or the right worker. Any exception is caught, the pair is
        aborted, and the exception is put on this worker's states queue so
        ``get_model_output`` re-raises it instead of blocking forever.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        states_queue = (self.states_left_queue
                        if is_left_thread else self.states_right_queue)
        try:
            init_tbo_forward_context(is_left_thread, tid)
            result = self._run_half_batch(queue, tid, is_left_thread)
        except Exception as exc:  # thread boundary: hand the error to the caller
            self._abort()
            result = exc
        states_queue.put(result)

    def _run_half_batch(self, model_queue, tid, is_left_thread):
        """Run ``model_executable`` on the half-batch taken from ``model_queue``."""
        with torch.cuda.stream(tbo_step_stream):
            model_input = model_queue.get()
            profile.ProfRangePush('start')
            try:
                # Wait for this worker's first turn.
                self.tbo_thread_synchronize(tid)
                if is_left_thread:
                    model_kwargs = self.model_kwargs_left
                    intermediate_tensors = self.intermediate_tensors_left
                else:
                    model_kwargs = self.model_kwargs_right
                    intermediate_tensors = self.intermediate_tensors_right
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **model_kwargs,
                    )
                if is_left_thread:
                    # The right worker is waiting for its final turn.
                    self.sem_right.release()
            finally:
                profile.ProfRangePop()
        return states

    def tbo_thread_synchronize(self, tid):
        """Hand the turn to the partner worker and wait to get it back.

        Returns the calling worker's ``(c2t, t2c)`` CUDA event pair. The left
        worker's first call (while ``left_first`` is set) does not release the
        right worker; the right worker's first call releases the left one, so
        the left worker runs first. After an abort this returns immediately.
        """
        is_left = tid == self.left_tid
        if is_left:
            events = (self.event_left_c2t, self.event_left_t2c)
        else:
            events = (self.event_right_c2t, self.event_right_t2c)
        if self.tbo_aborted:
            return events
        if is_left:
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
        return events

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors_left,
                        intermediate_tensors_right,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs_left,
                        model_kwargs_right):
        """Store the step's shared arguments, then release both workers.

        The per-half inputs are queued last, so the workers see every
        attribute set here once their ``queue.get()`` returns.
        """
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers finish; return ``(states_left, states_right)``.

        Raises:
            Exception: the exception raised by a worker thread, if any (the
                left worker's takes precedence).
        """
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        for states in (states_left, states_right):
            if isinstance(states, BaseException):
                raise states
        return states_left, states_right
    
# Process-wide TwoBatchOverlap instance, created on first use.
tbo_obj = None


def init_two_batch_overlap():
    """Create the shared TwoBatchOverlap on first use and start its worker threads."""
    global tbo_obj
    if tbo_obj is None:
        tbo_obj = TwoBatchOverlap()
    tbo_obj.init_tbo_thread()
159
160

def tbo_all_reduce(obj):
    """Tensor-parallel all-reduce of ``obj`` that cooperates with two-batch overlap.

    Outside a TBO step, or when called from a thread that is not one of the
    two TBO workers, this is just ``tensor_model_parallel_all_reduce(obj)``.
    Inside a TBO worker it issues the all-reduce and then yields the turn to
    the other worker via ``tbo_thread_synchronize``:

    * default: the all-reduce runs on ``all_reduce_stream``, ordered after the
      caller's earlier work by the worker's ``c2t`` event, and
      ``tbo_step_stream`` waits on the ``t2c`` event before continuing;
    * ``VLLM_TBO_ONE_STREAM=1``: the all-reduce runs on the current stream.

    Returns the all-reduced tensor.
    """
    if not (envs.VLLM_ENABLE_TBO and tbo_obj is not None
            and tbo_obj.tbo_running):
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    if tid != tbo_obj.left_tid and tid != tbo_obj.right_tid:
        # Only the two TBO workers take part in the turn-taking; any other
        # thread would be mistaken for the right worker by
        # tbo_thread_synchronize and disturb or block the ping-pong.
        return tensor_model_parallel_all_reduce(obj)

    if tbo_one_stream:
        output = tensor_model_parallel_all_reduce(obj)
        tbo_obj.tbo_thread_synchronize(tid)
        return output

    if tid == tbo_obj.left_tid:
        event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
    else:
        event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
    event_c2t.record()
    with torch.cuda.stream(all_reduce_stream):
        all_reduce_stream.wait_event(event_c2t)
        output = tensor_model_parallel_all_reduce(obj)
        event_t2c.record()
    # Let the partner worker compute while this all-reduce is in flight.
    tbo_obj.tbo_thread_synchronize(tid)
    tbo_step_stream.wait_event(event_t2c)
    return output

def merge_model_output(states_left, states_right):
    """Concatenate the left and right half-batch outputs along dim 0.

    ``IntermediateTensors`` are merged key by key, using the keys of
    ``states_left``. Anything else is passed to ``torch.concat`` directly.
    """
    if not isinstance(states_left, IntermediateTensors):
        return torch.concat([states_left, states_right], dim=0)
    return IntermediateTensors({
        key: torch.concat(
            [states_left.tensors[key], states_right.tensors[key]], dim=0)
        for key in states_left.tensors
    })

def _run_single_batch(model_input, vllm_config, virtual_engine,
                      model_executable, intermediate_tensors,
                      multi_modal_kwargs, self_device,
                      seqlen_agnostic_kwargs, model_kwargs):
    """Run one ordinary forward pass over the whole batch (no overlap)."""
    with set_forward_context(model_input.attn_metadata,
                             vllm_config, virtual_engine):
        return model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            intermediate_tensors=intermediate_tensors,
            **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                         device=self_device),
            **seqlen_agnostic_kwargs,
            **model_kwargs,
        )


def tbo_model_executable(
        model_input,
        vllm_config,
        virtual_engine,
        model_executable,
        intermediate_tensors,
        multi_modal_kwargs,
        self_device,
        seqlen_agnostic_kwargs,
        model_kwargs,
    ):
    """Run ``model_executable`` on ``model_input``, with two-batch overlap when possible.

    Falls back to a single plain forward pass when:

    * the attention metadata is not supported by the splitter,
    * the batch has fewer than two sequences, or
    * this is a decode step and ``VLLM_TBO_DECODE_BS`` is below 2, the batch
      is smaller than it, or the step uses CUDA graphs.

    Otherwise the batch is split into a left half of ``batch_size // 2``
    sequences and a right half holding the rest. Both halves run on the TBO
    worker threads, and their outputs are concatenated along dim 0.

    Returns whatever ``model_executable`` returns; on the TBO path this is the
    two halves merged by ``merge_model_output``.
    """
    single_batch_args = (model_input, vllm_config, virtual_engine,
                         model_executable, intermediate_tensors,
                         multi_modal_kwargs, self_device,
                         seqlen_agnostic_kwargs, model_kwargs)
    attn_metadata = model_input.attn_metadata
    if not is_supported_attention_metadata(attn_metadata):
        # Check support before touching attn_metadata.seq_lens.
        logger.info("tbo: not supported yet: %s", type(attn_metadata))
        return _run_single_batch(*single_batch_args)

    batch_size = len(attn_metadata.seq_lens)
    is_decode_tbo_invalid = not model_input.is_prompt and (
        envs.VLLM_TBO_DECODE_BS < 2
        or batch_size < envs.VLLM_TBO_DECODE_BS
        or attn_metadata.use_cuda_graph)
    if batch_size < 2 or is_decode_tbo_invalid:
        # A split needs at least one sequence per half.
        return _run_single_batch(*single_batch_args)

    profile.ProfRangePush('tbo_model_executable')
    try:
        batch_size_left = batch_size // 2
        batch_size_right = batch_size - batch_size_left
        # Split before starting the workers: if splitting fails, no thread is
        # left blocked on an input queue waiting to steal a later step's input.
        model_input_left, model_input_right = split_model_input(
            model_input, self_device, batch_size_left, batch_size_right)

        model_kwargs_left = model_kwargs.copy()
        model_kwargs_right = model_kwargs.copy()
        intermediate_tensors_left = None
        intermediate_tensors_right = None
        has_previous_hidden = "previous_hidden_states" in model_kwargs
        if has_previous_hidden or intermediate_tensors is not None:
            # Token counts of each half, used to split per-token tensors.
            query_tokens_split = [
                sum(model_input.query_lens[:batch_size_left]),
                sum(model_input.query_lens[batch_size_left:]),
            ]
        if has_previous_hidden:
            (model_kwargs_left["previous_hidden_states"],
             model_kwargs_right["previous_hidden_states"]) = torch.split(
                 model_kwargs["previous_hidden_states"], query_tokens_split,
                 dim=0)
        if intermediate_tensors is not None:
            left_map, right_map = {}, {}
            for key in intermediate_tensors.tensors:
                left_map[key], right_map[key] = torch.split(
                    intermediate_tensors.tensors[key], query_tokens_split,
                    dim=0)
            intermediate_tensors_left = IntermediateTensors(left_map)
            intermediate_tensors_right = IntermediateTensors(right_map)

        init_two_batch_overlap()
        tbo_obj.tbo_running = True
        tbo_obj.left_first = True

        # Order the TBO stream after all work already queued on the caller's stream.
        tbo_obj.step_event.record()
        current_stream = torch.cuda.current_stream()
        with torch.cuda.stream(tbo_step_stream):
            tbo_step_stream.wait_event(tbo_obj.step_event)
            tbo_obj.set_model_input(model_input_left,
                                    model_input_right,
                                    vllm_config,
                                    virtual_engine,
                                    model_executable,
                                    intermediate_tensors_left,
                                    intermediate_tensors_right,
                                    multi_modal_kwargs,
                                    self_device,
                                    seqlen_agnostic_kwargs,
                                    model_kwargs_left,
                                    model_kwargs_right)
            states_left, states_right = tbo_obj.get_model_output()
            hidden_or_intermediate_states = merge_model_output(
                states_left, states_right)
            tbo_obj.tbo_running = False
            tbo_obj.step_event.record()
            tbo_obj.finish_thread()
        # Make the caller's stream wait for the merged output.
        current_stream.wait_event(tbo_obj.step_event)
    finally:
        # Never leave tbo_all_reduce in TBO mode after a failed step.
        if tbo_obj is not None:
            tbo_obj.tbo_running = False
        profile.ProfRangePop()
    return hidden_or_intermediate_states