Commit 0d61a71c authored by zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-opt1'

V0.8.5 zero overhead

See merge request dcutoolkit/deeplearing/vllm!140
parents f5cbfe8f bd1e64d6
...@@ -127,8 +127,8 @@ if TYPE_CHECKING: ...@@ -127,8 +127,8 @@ if TYPE_CHECKING:
VLLM_FLASH_ATTN_BACKEND: bool = False VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_USE_NN: bool = False VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS:int = 0 VLLM_TBO_REQ_DELAY_MS: int = 0
VLLM_TBO_DECODE_BS: int = 0
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
...@@ -829,6 +829,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -829,6 +829,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TBO_REQ_DELAY_MS": "VLLM_TBO_REQ_DELAY_MS":
lambda: int(os.getenv("VLLM_TBO_REQ_DELAY_MS", "0")), lambda: int(os.getenv("VLLM_TBO_REQ_DELAY_MS", "0")),
# Minimum batch size required to enable TBO in decode; a value below 2 disables TBO in decode.
"VLLM_TBO_DECODE_BS":
lambda: int(os.getenv("VLLM_TBO_DECODE_BS", "0")),
# Enable zero overhead scheduler. # Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD": "VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))), lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
......
...@@ -14,6 +14,7 @@ from typing import Any, Callable, Optional, Union ...@@ -14,6 +14,7 @@ from typing import Any, Callable, Optional, Union
import msgspec import msgspec
import torch import torch
from vllm import envs
from vllm.inputs import SingletonInputs from vllm.inputs import SingletonInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
...@@ -809,18 +810,20 @@ class SequenceGroup: ...@@ -809,18 +810,20 @@ class SequenceGroup:
def set_last_token_time(self, now: float) -> None: def set_last_token_time(self, now: float) -> None:
"""Sets the last token time for Request level timings.""" """Sets the last token time for Request level timings."""
# If still in prefill phase, assertion fails. if not envs.VLLM_ZERO_OVERHEAD:
assert not self.is_prefill(), ( # If still in prefill phase, assertion fails.
"seq_group.set_last_token_time() should not be called " assert not self.is_prefill(), (
"if the seq_group is in prefill phase.") "seq_group.set_last_token_time() should not be called "
"if the seq_group is in prefill phase.")
self.last_token_latency = now - self.metrics.last_token_time self.last_token_latency = now - self.metrics.last_token_time
self.metrics.last_token_time = now self.metrics.last_token_time = now
def get_last_token_latency(self) -> float: def get_last_token_latency(self) -> float:
"""Returns the latency of the last token.""" """Returns the latency of the last token."""
assert not self.is_prefill(), ( if not envs.VLLM_ZERO_OVERHEAD:
"seq_group.get_last_token_latency() should not be called " assert not self.is_prefill(), (
"if the seq_group is in prefill phase.") "seq_group.get_last_token_latency() should not be called "
"if the seq_group is in prefill phase.")
return self.last_token_latency return self.last_token_latency
def maybe_set_first_token_time(self, time: float) -> None: def maybe_set_first_token_time(self, time: float) -> None:
......
...@@ -14,8 +14,6 @@ from vllm.logger import init_logger ...@@ -14,8 +14,6 @@ from vllm.logger import init_logger
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm import envs from vllm import envs
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1' tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,8 +29,6 @@ class TwoBatchOverlap(): ...@@ -31,8 +29,6 @@ class TwoBatchOverlap():
self.model_input_right_queue = queue.Queue() self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue() self.states_left_queue = queue.Queue()
self.states_right_queue = queue.Queue() self.states_right_queue = queue.Queue()
self.all_reduce_queue = queue.Queue()
self.all_reduce_out = queue.Queue()
self.left_thread = None self.left_thread = None
self.right_thread = None self.right_thread = None
self.left_tid = 0 self.left_tid = 0
...@@ -103,7 +99,6 @@ class TwoBatchOverlap(): ...@@ -103,7 +99,6 @@ class TwoBatchOverlap():
self.sem_right.release() self.sem_right.release()
self.states_left_queue.put(hidden_or_intermediate_states) self.states_left_queue.put(hidden_or_intermediate_states)
else: else:
self.all_reduce_queue.put(None)
self.states_right_queue.put(hidden_or_intermediate_states) self.states_right_queue.put(hidden_or_intermediate_states)
profile.ProfRangePop() profile.ProfRangePop()
...@@ -154,22 +149,6 @@ class TwoBatchOverlap(): ...@@ -154,22 +149,6 @@ class TwoBatchOverlap():
states_right = self.states_right_queue.get() states_right = self.states_right_queue.get()
return states_left, states_right return states_left, states_right
def all_reduce(self):
    """Serve queued all-reduce requests until a ``None`` sentinel arrives.

    Items are taken one at a time from ``self.all_reduce_queue``. ``None``
    ends the loop. Any other item is unpacked as the three-element sequence
    ``(buf, event_c2t, event_t2c)``; ``buf`` is passed to
    ``tensor_model_parallel_all_reduce`` and the result is put on
    ``self.all_reduce_out``.

    When the module-level flag ``tbo_one_stream`` is set, the all-reduce is
    issued directly. Otherwise it is issued under
    ``torch.cuda.stream(all_reduce_stream)``, bracketed by the two events
    that came with the request.
    """
    while True:
        request = self.all_reduce_queue.get()
        # Identity test for the sentinel: `== None` would dispatch to the
        # item's own __eq__, which is not guaranteed to return a plain bool.
        if request is None:
            break
        buf, event_c2t, event_t2c = request
        if tbo_one_stream:
            output = tensor_model_parallel_all_reduce(buf)
        else:
            # Record event_c2t before switching streams, then make
            # all_reduce_stream wait on it before running the all-reduce.
            event_c2t.record()
            with torch.cuda.stream(all_reduce_stream):
                all_reduce_stream.wait_event(event_c2t)
                output = tensor_model_parallel_all_reduce(buf)
                # NOTE(review): the original indentation was lost; this
                # record is assumed to sit inside the stream context so
                # that it follows the all-reduce on all_reduce_stream.
                # Confirm against the repository.
                event_t2c.record()
        self.all_reduce_out.put(output)
tbo_obj = None tbo_obj = None
def init_two_batch_overlap(): def init_two_batch_overlap():
...@@ -186,11 +165,16 @@ def tbo_all_reduce(obj): ...@@ -186,11 +165,16 @@ def tbo_all_reduce(obj):
event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
else: else:
event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c]) event_c2t.record()
output = tbo_obj.all_reduce_out.get() with torch.cuda.stream(all_reduce_stream):
tbo_obj.tbo_thread_synchronize(tid) all_reduce_stream.wait_event(event_c2t)
if not tbo_one_stream: output = tensor_model_parallel_all_reduce(obj)
event_t2c.record()
tbo_obj.tbo_thread_synchronize(tid)
tbo_step_stream.wait_event(event_t2c) tbo_step_stream.wait_event(event_t2c)
else:
output = tensor_model_parallel_all_reduce(obj)
tbo_obj.tbo_thread_synchronize(tid)
return output return output
return tensor_model_parallel_all_reduce(obj) return tensor_model_parallel_all_reduce(obj)
...@@ -218,12 +202,14 @@ def tbo_model_executable( ...@@ -218,12 +202,14 @@ def tbo_model_executable(
is_support = is_supported_attention_metadata(model_input.attn_metadata) is_support = is_supported_attention_metadata(model_input.attn_metadata)
if not is_support: if not is_support:
logger.info("tbo:not surpport yet ", type(model_input.attn_metadata)) logger.info("tbo:not surpport yet ", type(model_input.attn_metadata))
is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
batch_size = len(model_input.attn_metadata.seq_lens) batch_size = len(model_input.attn_metadata.seq_lens)
is_decode_tbo_invalid = not model_input.is_prompt and (
envs.VLLM_TBO_DECODE_BS < 2 or
batch_size < envs.VLLM_TBO_DECODE_BS or
model_input.attn_metadata.use_cuda_graph)
if batch_size == 1 or \ if batch_size == 1 or \
(not model_input.is_prompt and not enable_tbo_decode) or \ is_decode_tbo_invalid or \
not is_support or \ not is_support:
is_cuda_graph_decode:
with set_forward_context(model_input.attn_metadata, with set_forward_context(model_input.attn_metadata,
vllm_config, virtual_engine): vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
...@@ -284,7 +270,7 @@ def tbo_model_executable( ...@@ -284,7 +270,7 @@ def tbo_model_executable(
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
model_kwargs_left, model_kwargs_left,
model_kwargs_right) model_kwargs_right)
tbo_obj.all_reduce()
states_left, states_right = tbo_obj.get_model_output() states_left, states_right = tbo_obj.get_model_output()
hidden_or_intermediate_states = merge_model_output(states_left, states_right) hidden_or_intermediate_states = merge_model_output(states_left, states_right)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment