two_batch_overlap.py 10.9 KB
Newer Older
1
2
3
4
5

import os
import queue
import threading
import torch
lizhigong's avatar
lizhigong committed
6
from vllm.attention.backends.flashmla import FlashMLAMetadata
7
8
9
10
11
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
lizhigong's avatar
lizhigong committed
12
from vllm.two_batch_overlap.model_input_split import split_model_input
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from vllm.utils import async_tensor_h2d
from vllm.logger import init_logger
from vllm.profiler.prof import profile

def _env_flag(name):
    # A flag counts as enabled only when the variable is set to exactly '1'.
    return os.environ.get(name) == '1'


# Master switch for two-batch overlap (TBO).
enable_tbo = _env_flag('VLLM_ENABLE_TBO')

# When off, decode (non-prompt) steps bypass TBO and run as one batch.
enable_tbo_decode = _env_flag('VLLM_TBO_DECODE')

# When on, TBO all-reduces are issued on the caller's current stream instead
# of the dedicated all-reduce stream.
tbo_one_stream = _env_flag('VLLM_TBO_ONE_STREAM')

# Module-level logger; emits the 'tbo:' worker-thread lifecycle messages.
logger = init_logger(__name__)

def is_enable_tbo():
    """Return whether two-batch overlap was switched on via VLLM_ENABLE_TBO.

    The environment variable is read once, when this module is imported;
    later changes to the environment are not observed.
    """
    return enable_tbo

28
29
30
# CUDA streams shared by all TwoBatchOverlap instances. Both start out unset
# and are created together, once, by the first TwoBatchOverlap() construction.
tbo_step_stream = all_reduce_stream = None

31
32
class TwoBatchOverlap:
    """Drive two-batch overlap (TBO): two half-batches on two worker threads.

    The left and right halves of a batch are executed by two worker threads,
    both issuing GPU work on ``tbo_step_stream``.  The workers take turns:
    each blocks on its own semaphore and is woken by the other.  When a worker
    reaches an all-reduce (see ``tbo_all_reduce``) it hands the tensor to the
    driving thread, which serves it in ``all_reduce()``, and then yields its
    turn so the other half can compute in the meantime.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream

        # Half-batch inputs for the left / right workers (None = shut down).
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Half-batch outputs produced by the left / right workers.
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # Worker -> driver all-reduce requests (None ends all_reduce()), and
        # driver -> worker all-reduce results.
        self.all_reduce_queue = queue.Queue()
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread identities, used to tell the two workers apart.
        self.left_tid = 0
        self.right_tid = 0
        # Turn-taking: each worker waits on its own semaphore and is released
        # by the other worker.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # Set at the start of every step so that the left half runs first.
        self.left_first = False
        # True while a TBO step is in progress; checked by tbo_all_reduce().
        self.tbo_running = False

        # The two streams are module-wide and created lazily, exactly once.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()

        # Orders tbo_step_stream against the caller's stream around a step.
        self.step_event = torch.cuda.Event(enable_timing=False)
        # Per-half all-reduce ordering events: *_c2t makes all_reduce_stream
        # wait for the compute that produced the buffer; *_t2c makes later
        # compute on tbo_step_stream wait for the all-reduce.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(q):
        """Discard every item currently waiting in ``q`` without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start the left and right worker threads.

        Stale items in the input queues are discarded first, so a new worker
        never picks up input left over from a previous run.  (``Queue.empty()``
        only reports emptiness; it does not clear the queue.)
        """
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.left_thread.start()
        self.right_thread.start()

    def finish_thread(self):
        """Stop and join both worker threads; threads never started are skipped."""
        if self.left_thread is not None:
            # None is the worker loop's shutdown sentinel.
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop for one half of the batch.

        ``queue`` is either ``model_input_left_queue`` or
        ``model_input_right_queue``; which one decides whether this thread
        plays the left or the right role.  Each item is a half-batch model
        input; ``None`` ends the loop.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
        else:
            self.right_tid = tid
        logger.info('tbo:new thread %d', tid)
        init_tbo_forward_context(is_left_thread, tid)

        with torch.cuda.stream(tbo_step_stream):
            while True:
                model_input = queue.get()
                if model_input is None:
                    break
                profile.ProfRangePush('start')
                # Block until it is this half's turn to run.
                self.tbo_thread_synchronize(tid)
                # NOTE(review): if the model raises here this thread exits
                # without producing output, which likely leaves the driver
                # blocked in all_reduce()/get_model_output() — consider
                # propagating worker errors.
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=self.intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **self.model_kwargs,
                    )
                if is_left_thread:
                    # Left is done: hand the turn to the right half for good.
                    self.sem_right.release()
                    self.states_left_queue.put(hidden_or_intermediate_states)
                else:
                    # Presumably the right half finishes last (left runs first
                    # and they alternate), so it ends the driver's all_reduce().
                    self.all_reduce_queue.put(None)
                    self.states_right_queue.put(hidden_or_intermediate_states)
                profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Yield the turn to the other worker and wait to get it back.

        On the first call of a step (``left_first`` set) the left worker does
        not release the right one, so the left half runs first.

        Returns:
            ``(event_c2t, event_t2c)`` for the calling half.
        """
        if tid == self.left_tid:
            if not self.left_first:
                self.sem_right.release()
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs):
        """Store the shared run arguments and dispatch both halves.

        The worker threads are started on first use.  The shared arguments
        are stored on ``self`` before the inputs are queued, so a worker that
        dequeues an input sees the matching arguments.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors = intermediate_tensors
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs = model_kwargs
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both halves have finished; return ``(left, right)``."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Serve worker all-reduce requests on the driving thread.

        Each request is ``[buf, event_c2t, event_t2c]``.  With
        VLLM_TBO_ONE_STREAM=1 the all-reduce is issued on the current stream
        and the events are ignored.  Otherwise it is issued on
        ``all_reduce_stream``, ordered after the producing compute by
        ``event_c2t``, and followed by a record of ``event_t2c``.  Each result
        is put on ``all_reduce_out``.  Returns when a ``None`` request arrives.
        """
        while True:
            request = self.all_reduce_queue.get()
            if request is None:
                break
            buf, event_c2t, event_t2c = request
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                event_c2t.record()
                with torch.cuda.stream(all_reduce_stream):
                    all_reduce_stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    event_t2c.record()
            self.all_reduce_out.put(output)

# Process-wide TwoBatchOverlap instance; stays None until
# init_two_batch_overlap() creates it, which only happens when TBO is enabled.
tbo_obj = None


def init_two_batch_overlap():
    """Create the global TwoBatchOverlap instance on first use.

    Does nothing when TBO is disabled (VLLM_ENABLE_TBO != '1') or when the
    instance already exists.
    """
    global tbo_obj
    if enable_tbo and tbo_obj is None:
        tbo_obj = TwoBatchOverlap()

def finish_two_batch_overlap():
    """Stop the TBO worker threads and drop the global instance.

    Safe to call when TBO was never initialised.
    """
    global tbo_obj
    instance = tbo_obj
    if instance is None:
        return
    instance.finish_thread()
    tbo_obj = None

def tbo_all_reduce(obj):
    """All-reduce ``obj``, routing through the TBO driver while a step runs.

    When TBO is disabled, not initialised, or no TBO step is in progress,
    this is a plain ``tensor_model_parallel_all_reduce(obj)``.  During a TBO
    step the calling worker thread instead hands ``obj`` to the driving
    thread (``TwoBatchOverlap.all_reduce``), waits for the result, and yields
    its turn to the other worker before continuing.

    Args:
        obj: the buffer to all-reduce (presumably a tensor — it is passed
            straight to tensor_model_parallel_all_reduce).

    Returns:
        The all-reduced result.
    """
    if not (enable_tbo and tbo_obj is not None and tbo_obj.tbo_running):
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    # The ordering events are only used in two-stream mode, but they are
    # always sent to the driver, so they must be bound in one-stream mode too.
    event_c2t = event_t2c = None
    if not tbo_one_stream:
        if tid == tbo_obj.left_tid:
            event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
        else:
            event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c

    tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
    output = tbo_obj.all_reduce_out.get()
    # Give the other half its turn before this half continues.
    tbo_obj.tbo_thread_synchronize(tid)
    if not tbo_one_stream:
        # Later compute on the TBO stream must wait for the all-reduce.
        tbo_step_stream.wait_event(event_t2c)
    return output

def merge_model_output(states_left, states_right):
    """Concatenate the two halves' outputs along dim 0, left half first."""
    return torch.cat((states_left, states_right), dim=0)

def tbo_model_executable(
    model_input,
    vllm_config,
    virtual_engine,
    model_executable,
    intermediate_tensors,
    multi_modal_kwargs,
    self_device,
    seqlen_agnostic_kwargs,
    model_kwargs,
):
    """Run the model on ``model_input``, using two-batch overlap when possible.

    The TBO path splits the batch into two halves that run on the worker
    threads of the global TwoBatchOverlap instance.  It is used only when:

    * TBO is enabled (the global instance exists),
    * the batch has at least two sequences,
    * the step is a prompt step, or decode TBO is enabled (VLLM_TBO_DECODE),
    * the attention metadata is ROCm flash-attention or FlashMLA, and
    * the step is not a CUDA-graph decode.

    Otherwise the model runs once over the whole batch on the current stream.

    Returns:
        The model's hidden or intermediate states for the whole batch. On the
        TBO path these are the two halves concatenated along dim 0.
    """
    init_two_batch_overlap()
    attn_metadata = model_input.attn_metadata
    is_rocm_fa = isinstance(attn_metadata, ROCmFlashAttentionMetadata)
    is_mla_fa = isinstance(attn_metadata, FlashMLAMetadata)
    is_cuda_graph_decode = (attn_metadata.use_cuda_graph
                            and not model_input.is_prompt)
    batch_size = len(attn_metadata.seq_lens)

    # tbo_obj is None when TBO is disabled; it must not be dereferenced then.
    use_tbo = (tbo_obj is not None
               and batch_size >= 2
               and (model_input.is_prompt or enable_tbo_decode)
               and (is_rocm_fa or is_mla_fa)
               and not is_cuda_graph_decode)
    if not use_tbo:
        with set_forward_context(attn_metadata, vllm_config, virtual_engine):
            return model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )

    profile.ProfRangePush('tbo_model_executable')
    tbo_obj.left_first = True
    # The right half takes the extra sequence when batch_size is odd.
    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left
    model_input_left, model_input_right = split_model_input(
        model_input, self_device, batch_size_left, batch_size_right)

    # Make the TBO stream wait for work already queued on the caller's stream.
    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        # Set only right before dispatch and always cleared again. A stale
        # True would route every later all-reduce into the TBO queues, where
        # nothing serves it.
        tbo_obj.tbo_running = True
        try:
            tbo_obj.set_model_input(model_input_left,
                                    model_input_right,
                                    vllm_config,
                                    virtual_engine,
                                    model_executable,
                                    intermediate_tensors,
                                    multi_modal_kwargs,
                                    self_device,
                                    seqlen_agnostic_kwargs,
                                    model_kwargs)
            # Serve the workers' all-reduces until the right half finishes.
            tbo_obj.all_reduce()
            states_left, states_right = tbo_obj.get_model_output()
        finally:
            tbo_obj.tbo_running = False
        hidden_or_intermediate_states = merge_model_output(states_left,
                                                           states_right)
        tbo_obj.step_event.record()
    # Make the caller's stream wait for everything done on the TBO stream.
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states