"vscode:/vscode.git/clone" did not exist on "aecdff1869c6ae2c923e9d6f164d20f3dd917bcf"
Commit 828aeaae authored by lizhigong's avatar lizhigong
Browse files

优化stream的初始化和warmup方式

parent 56ffc380
...@@ -24,8 +24,13 @@ logger = init_logger(__name__) ...@@ -24,8 +24,13 @@ logger = init_logger(__name__)
def is_enable_tbo(): def is_enable_tbo():
return enable_tbo return enable_tbo
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap(): class TwoBatchOverlap():
def __init__(self): def __init__(self):
global tbo_step_stream
global all_reduce_stream
self.model_input_left_queue = queue.Queue() self.model_input_left_queue = queue.Queue()
self.model_input_right_queue = queue.Queue() self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue() self.states_left_queue = queue.Queue()
...@@ -40,8 +45,9 @@ class TwoBatchOverlap(): ...@@ -40,8 +45,9 @@ class TwoBatchOverlap():
self.sem_right = threading.Semaphore(0) self.sem_right = threading.Semaphore(0)
self.left_first = False self.left_first = False
self.tbo_running = False self.tbo_running = False
self.stream = torch.cuda.Stream() if tbo_step_stream == None:
self.step_stream = torch.cuda.Stream() tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False) self.step_event = torch.cuda.Event(enable_timing=False)
self.event_left_c2t = torch.cuda.Event(enable_timing=False) self.event_left_c2t = torch.cuda.Event(enable_timing=False)
self.event_right_c2t = torch.cuda.Event(enable_timing=False) self.event_right_c2t = torch.cuda.Event(enable_timing=False)
...@@ -80,7 +86,7 @@ class TwoBatchOverlap(): ...@@ -80,7 +86,7 @@ class TwoBatchOverlap():
self.right_tid = tid self.right_tid = tid
logger.info('tbo:new thread %d', self.right_tid) logger.info('tbo:new thread %d', self.right_tid)
init_tbo_forward_context(False, self.right_tid) init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(self.step_stream): with torch.cuda.stream(tbo_step_stream):
while True: while True:
model_input = queue.get() model_input = queue.get()
if model_input == None: if model_input == None:
...@@ -161,8 +167,8 @@ class TwoBatchOverlap(): ...@@ -161,8 +167,8 @@ class TwoBatchOverlap():
output = tensor_model_parallel_all_reduce(buf) output = tensor_model_parallel_all_reduce(buf)
else: else:
event_c2t.record() event_c2t.record()
with torch.cuda.stream(self.stream): with torch.cuda.stream(all_reduce_stream):
self.stream.wait_event(event_c2t) all_reduce_stream.wait_event(event_c2t)
output = tensor_model_parallel_all_reduce(buf) output = tensor_model_parallel_all_reduce(buf)
event_t2c.record() event_t2c.record()
self.all_reduce_out.put(output) self.all_reduce_out.put(output)
...@@ -193,7 +199,7 @@ def tbo_all_reduce(obj): ...@@ -193,7 +199,7 @@ def tbo_all_reduce(obj):
output = tbo_obj.all_reduce_out.get() output = tbo_obj.all_reduce_out.get()
tbo_obj.tbo_thread_synchronize(tid) tbo_obj.tbo_thread_synchronize(tid)
if not tbo_one_stream: if not tbo_one_stream:
tbo_obj.step_stream.wait_event(event_t2c) tbo_step_stream.wait_event(event_t2c)
return output return output
return tensor_model_parallel_all_reduce(obj) return tensor_model_parallel_all_reduce(obj)
...@@ -418,7 +424,6 @@ def tbo_model_executable( ...@@ -418,7 +424,6 @@ def tbo_model_executable(
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
model_kwargs, model_kwargs,
): ):
profile.ProfRangePush('tbo_model_executable')
init_two_batch_overlap() init_two_batch_overlap()
is_rocm_fa = isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata) is_rocm_fa = isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata)
is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
...@@ -439,6 +444,7 @@ def tbo_model_executable( ...@@ -439,6 +444,7 @@ def tbo_model_executable(
**model_kwargs, **model_kwargs,
) )
return hidden_or_intermediate_states return hidden_or_intermediate_states
profile.ProfRangePush('tbo_model_executable')
tbo_obj.tbo_running = True tbo_obj.tbo_running = True
tbo_obj.left_first = True tbo_obj.left_first = True
batch_size_left = int(batch_size / 2) batch_size_left = int(batch_size / 2)
...@@ -446,11 +452,11 @@ def tbo_model_executable( ...@@ -446,11 +452,11 @@ def tbo_model_executable(
if batch_size % 2 == 1: if batch_size % 2 == 1:
batch_size_right += 1 batch_size_right += 1
model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
tbo_obj.step_event.record() tbo_obj.step_event.record()
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
with torch.cuda.stream(tbo_obj.step_stream): with torch.cuda.stream(tbo_step_stream):
tbo_obj.step_stream.wait_event(tbo_obj.step_event) tbo_step_stream.wait_event(tbo_obj.step_event)
model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
tbo_obj.set_model_input(model_input_left, tbo_obj.set_model_input(model_input_left,
model_input_right, model_input_right,
vllm_config, vllm_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment