Commit 828aeaae authored by lizhigong
Browse files

优化stream的初始化和warmup方式 (Optimize how the streams are initialized and warmed up)

parent 56ffc380
......@@ -24,8 +24,13 @@ logger = init_logger(__name__)
def is_enable_tbo():
    """Return the current value of the module-level ``enable_tbo`` flag.

    NOTE(review): "tbo" presumably stands for two-batch overlap, after the
    ``TwoBatchOverlap`` class in this module; confirm where ``enable_tbo``
    is defined and set.
    """
    tbo_flag = enable_tbo
    return tbo_flag
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
def __init__(self):
global tbo_step_stream
global all_reduce_stream
self.model_input_left_queue = queue.Queue()
self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue()
......@@ -40,8 +45,9 @@ class TwoBatchOverlap():
self.sem_right = threading.Semaphore(0)
self.left_first = False
self.tbo_running = False
self.stream = torch.cuda.Stream()
self.step_stream = torch.cuda.Stream()
if tbo_step_stream == None:
tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False)
self.event_left_c2t = torch.cuda.Event(enable_timing=False)
self.event_right_c2t = torch.cuda.Event(enable_timing=False)
......@@ -80,7 +86,7 @@ class TwoBatchOverlap():
self.right_tid = tid
logger.info('tbo:new thread %d', self.right_tid)
init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(self.step_stream):
with torch.cuda.stream(tbo_step_stream):
while True:
model_input = queue.get()
if model_input == None:
......@@ -161,8 +167,8 @@ class TwoBatchOverlap():
output = tensor_model_parallel_all_reduce(buf)
else:
event_c2t.record()
with torch.cuda.stream(self.stream):
self.stream.wait_event(event_c2t)
with torch.cuda.stream(all_reduce_stream):
all_reduce_stream.wait_event(event_c2t)
output = tensor_model_parallel_all_reduce(buf)
event_t2c.record()
self.all_reduce_out.put(output)
......@@ -193,7 +199,7 @@ def tbo_all_reduce(obj):
output = tbo_obj.all_reduce_out.get()
tbo_obj.tbo_thread_synchronize(tid)
if not tbo_one_stream:
tbo_obj.step_stream.wait_event(event_t2c)
tbo_step_stream.wait_event(event_t2c)
return output
return tensor_model_parallel_all_reduce(obj)
......@@ -418,7 +424,6 @@ def tbo_model_executable(
seqlen_agnostic_kwargs,
model_kwargs,
):
profile.ProfRangePush('tbo_model_executable')
init_two_batch_overlap()
is_rocm_fa = isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata)
is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
......@@ -439,6 +444,7 @@ def tbo_model_executable(
**model_kwargs,
)
return hidden_or_intermediate_states
profile.ProfRangePush('tbo_model_executable')
tbo_obj.tbo_running = True
tbo_obj.left_first = True
batch_size_left = int(batch_size / 2)
......@@ -446,11 +452,11 @@ def tbo_model_executable(
if batch_size % 2 == 1:
batch_size_right += 1
model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
tbo_obj.step_event.record()
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(tbo_obj.step_stream):
tbo_obj.step_stream.wait_event(tbo_obj.step_event)
model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
with torch.cuda.stream(tbo_step_stream):
tbo_step_stream.wait_event(tbo_obj.step_event)
tbo_obj.set_model_input(model_input_left,
model_input_right,
vllm_config,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment