Commit 56ffc380 authored by lizhigong
Browse files

调试tbo正确性

parent 2a935929
......@@ -41,6 +41,8 @@ class TwoBatchOverlap():
self.left_first = False
self.tbo_running = False
self.stream = torch.cuda.Stream()
self.step_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False)
self.event_left_c2t = torch.cuda.Event(enable_timing=False)
self.event_right_c2t = torch.cuda.Event(enable_timing=False)
self.event_left_t2c = torch.cuda.Event(enable_timing=False)
......@@ -68,21 +70,23 @@ class TwoBatchOverlap():
@torch.inference_mode()
def thread_two_batch_overlap(self, queue):
is_left_thread = False
tid = threading.get_ident()
if queue == self.model_input_left_queue:
self.left_tid = threading.get_ident()
self.left_tid = tid
is_left_thread = True
logger.info('tbo:new thread %d', self.left_tid)
init_tbo_forward_context(True, self.left_tid)
else:
self.right_tid = threading.get_ident()
self.right_tid = tid
logger.info('tbo:new thread %d', self.right_tid)
init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(self.step_stream):
while True:
model_input = queue.get()
if model_input == None:
break
profile.ProfRangePush('start')
self.tbo_thread_synchronize(False)
self.tbo_thread_synchronize(tid)
with set_forward_context(model_input.attn_metadata,
self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable(
......@@ -94,20 +98,16 @@ class TwoBatchOverlap():
**self.seqlen_agnostic_kwargs,
**self.model_kwargs,
)
profile.ProfRangePush('end')
if is_left_thread:
self.sem_right.release()
self.states_left_queue.put(hidden_or_intermediate_states)
else:
self.all_reduce_queue.put(None)
self.states_right_queue.put(hidden_or_intermediate_states)
profile.ProfRangePop()
def tbo_thread_synchronize(self, recode_flag = True):
tid = threading.get_ident()
def tbo_thread_synchronize(self, tid):
if tid == self.left_tid:
if recode_flag and not tbo_one_stream:
print('###left_c2t_recorded')
self.event_left_c2t.record()
if not self.left_first:
self.sem_right.release()
profile.ProfRangePop()
......@@ -116,9 +116,6 @@ class TwoBatchOverlap():
self.left_first = False
return self.event_left_c2t, self.event_left_t2c
else:
if recode_flag and not tbo_one_stream:
print('###right_c2t_recorded')
self.event_right_c2t.record()
self.sem_left.release()
profile.ProfRangePop()
self.sem_right.acquire()
......@@ -160,17 +157,14 @@ class TwoBatchOverlap():
if obj == None:
break
buf, event_c2t, event_t2c = obj
#print('###buf', buf[0,0:5])
if tbo_one_stream:
output = tensor_model_parallel_all_reduce(buf)
else:
event_c2t.record()
with torch.cuda.stream(self.stream):
print('###stream.wait_event event_c2t before all_reduce')
self.stream.wait_event(event_c2t)
output = tensor_model_parallel_all_reduce(buf)
print('###event_t2c recorded')
event_t2c.record()
#print('###print', output[0,0:5])
self.all_reduce_out.put(output)
tbo_obj = None
......@@ -189,13 +183,17 @@ def finish_two_batch_overlap():
def tbo_all_reduce(obj):
if enable_tbo and tbo_obj != None and tbo_obj.tbo_running:
event_c2t, event_t2c = tbo_obj.tbo_thread_synchronize()
tid = threading.get_ident()
if not tbo_one_stream:
if tid == tbo_obj.left_tid:
event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
else:
event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
output = tbo_obj.all_reduce_out.get()
tbo_obj.tbo_thread_synchronize(tid)
if not tbo_one_stream:
current_stream = torch.cuda.current_stream()
print('###current_stream wait event event_t2c')
current_stream.wait_event(event_t2c)
tbo_obj.step_stream.wait_event(event_t2c)
return output
return tensor_model_parallel_all_reduce(obj)
......@@ -420,6 +418,7 @@ def tbo_model_executable(
seqlen_agnostic_kwargs,
model_kwargs,
):
profile.ProfRangePush('tbo_model_executable')
init_two_batch_overlap()
is_rocm_fa = isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata)
is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
......@@ -446,6 +445,11 @@ def tbo_model_executable(
batch_size_right = batch_size_left
if batch_size % 2 == 1:
batch_size_right += 1
tbo_obj.step_event.record()
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(tbo_obj.step_stream):
tbo_obj.step_stream.wait_event(tbo_obj.step_event)
model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
tbo_obj.set_model_input(model_input_left,
model_input_right,
......@@ -462,4 +466,7 @@ def tbo_model_executable(
hidden_or_intermediate_states = merge_model_output(states_left, states_right)
tbo_obj.tbo_running = False
tbo_obj.step_event.record()
current_stream.wait_event(tbo_obj.step_event)
profile.ProfRangePop()
return hidden_or_intermediate_states
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment