Commit aa906d98 authored by lizhigong
Browse files

Add VLLM_TBO_DECODE_BS to support setting the minimum batch size for TBO decode

parent 59488cc9
...@@ -126,8 +126,8 @@ if TYPE_CHECKING: ...@@ -126,8 +126,8 @@ if TYPE_CHECKING:
VLLM_HAS_CONTEXT_DEFAULT: bool = False VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_BACKEND: bool = False VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_ENABLE_TBO: bool = False VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS:int = 0 VLLM_TBO_REQ_DELAY_MS: int = 0
VLLM_TBO_DECODE_BS: int = 0
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
...@@ -823,6 +823,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -823,6 +823,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TBO_REQ_DELAY_MS": "VLLM_TBO_REQ_DELAY_MS":
lambda: int(os.getenv("VLLM_TBO_REQ_DELAY_MS", "0")), lambda: int(os.getenv("VLLM_TBO_REQ_DELAY_MS", "0")),
# Minimum batch size required to enable TBO in decode. Values below 2 (including the default 0) disable TBO in decode.
"VLLM_TBO_DECODE_BS":
lambda: int(os.getenv("VLLM_TBO_DECODE_BS", "0")),
# Enable zero overhead scheduler. # Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD": "VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))), lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
......
...@@ -14,8 +14,6 @@ from vllm.logger import init_logger ...@@ -14,8 +14,6 @@ from vllm.logger import init_logger
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm import envs from vllm import envs
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1' tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,8 +29,6 @@ class TwoBatchOverlap(): ...@@ -31,8 +29,6 @@ class TwoBatchOverlap():
self.model_input_right_queue = queue.Queue() self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue() self.states_left_queue = queue.Queue()
self.states_right_queue = queue.Queue() self.states_right_queue = queue.Queue()
self.all_reduce_queue = queue.Queue()
self.all_reduce_out = queue.Queue()
self.left_thread = None self.left_thread = None
self.right_thread = None self.right_thread = None
self.left_tid = 0 self.left_tid = 0
...@@ -103,7 +99,6 @@ class TwoBatchOverlap(): ...@@ -103,7 +99,6 @@ class TwoBatchOverlap():
self.sem_right.release() self.sem_right.release()
self.states_left_queue.put(hidden_or_intermediate_states) self.states_left_queue.put(hidden_or_intermediate_states)
else: else:
self.all_reduce_queue.put(None)
self.states_right_queue.put(hidden_or_intermediate_states) self.states_right_queue.put(hidden_or_intermediate_states)
profile.ProfRangePop() profile.ProfRangePop()
...@@ -154,22 +149,6 @@ class TwoBatchOverlap(): ...@@ -154,22 +149,6 @@ class TwoBatchOverlap():
states_right = self.states_right_queue.get() states_right = self.states_right_queue.get()
return states_left, states_right return states_left, states_right
def all_reduce(self):
while True:
obj = self.all_reduce_queue.get()
if obj == None:
break
buf, event_c2t, event_t2c = obj
if tbo_one_stream:
output = tensor_model_parallel_all_reduce(buf)
else:
event_c2t.record()
with torch.cuda.stream(all_reduce_stream):
all_reduce_stream.wait_event(event_c2t)
output = tensor_model_parallel_all_reduce(buf)
event_t2c.record()
self.all_reduce_out.put(output)
tbo_obj = None tbo_obj = None
def init_two_batch_overlap(): def init_two_batch_overlap():
...@@ -186,11 +165,16 @@ def tbo_all_reduce(obj): ...@@ -186,11 +165,16 @@ def tbo_all_reduce(obj):
event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
else: else:
event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c]) event_c2t.record()
output = tbo_obj.all_reduce_out.get() with torch.cuda.stream(all_reduce_stream):
tbo_obj.tbo_thread_synchronize(tid) all_reduce_stream.wait_event(event_c2t)
if not tbo_one_stream: output = tensor_model_parallel_all_reduce(obj)
event_t2c.record()
tbo_obj.tbo_thread_synchronize(tid)
tbo_step_stream.wait_event(event_t2c) tbo_step_stream.wait_event(event_t2c)
else:
output = tensor_model_parallel_all_reduce(obj)
tbo_obj.tbo_thread_synchronize(tid)
return output return output
return tensor_model_parallel_all_reduce(obj) return tensor_model_parallel_all_reduce(obj)
...@@ -218,12 +202,14 @@ def tbo_model_executable( ...@@ -218,12 +202,14 @@ def tbo_model_executable(
is_support = is_supported_attention_metadata(model_input.attn_metadata) is_support = is_supported_attention_metadata(model_input.attn_metadata)
if not is_support: if not is_support:
logger.info("tbo:not surpport yet ", type(model_input.attn_metadata)) logger.info("tbo:not surpport yet ", type(model_input.attn_metadata))
is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
batch_size = len(model_input.attn_metadata.seq_lens) batch_size = len(model_input.attn_metadata.seq_lens)
is_decode_tbo_invalid = not model_input.is_prompt and (
envs.VLLM_TBO_DECODE_BS < 2 or
batch_size < envs.VLLM_TBO_DECODE_BS or
model_input.attn_metadata.use_cuda_graph)
if batch_size == 1 or \ if batch_size == 1 or \
(not model_input.is_prompt and not enable_tbo_decode) or \ is_decode_tbo_invalid or \
not is_support or \ not is_support:
is_cuda_graph_decode:
with set_forward_context(model_input.attn_metadata, with set_forward_context(model_input.attn_metadata,
vllm_config, virtual_engine): vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
...@@ -284,7 +270,7 @@ def tbo_model_executable( ...@@ -284,7 +270,7 @@ def tbo_model_executable(
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
model_kwargs_left, model_kwargs_left,
model_kwargs_right) model_kwargs_right)
tbo_obj.all_reduce()
states_left, states_right = tbo_obj.get_model_output() states_left, states_right = tbo_obj.get_model_output()
hidden_or_intermediate_states = merge_model_output(states_left, states_right) hidden_or_intermediate_states = merge_model_output(states_left, states_right)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment