Commit 48eb976d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' into v0.9.2-dev-ds

parents 2eb579dd 99863602
...@@ -36,6 +36,7 @@ class ReqMeta: ...@@ -36,6 +36,7 @@ class ReqMeta:
token_ids: torch.Tensor token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids # Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
slot_mapping_device: torch.Tensor = None
@staticmethod @staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
...@@ -273,9 +274,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -273,9 +274,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx) Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise. if MLA is not used, and (num_pages, page_size, xxx) otherwise.
""" """
if envs.VLLM_ENABLE_TBO:
slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if isinstance(attn_metadata, MLACommonMetadata): if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1] num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...@@ -286,34 +285,72 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -286,34 +285,72 @@ class P2pNcclConnector(KVConnectorBase_V1):
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata) assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
ip, port = self.parse_request_id(request_id, True) for request in connector_metadata.requests:
remote_address = ip + ":" + str(port + self._rank) request_id = request.request_id
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size slot_mapping = request.slot_mapping
) % self.parallel_config.pipeline_parallel_size if request.slot_mapping_device is None:
if (self.pp_size == 1): request.slot_mapping_device = \
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True)
kv_cache, remote_address) slot_mapping = request.slot_mapping_device
elif (self.pp_size == 2): kv_cache = extract_kv_from_layer(kv_layer, slot_mapping)
if (pp_rank == 0): tbo_evt = torch.cuda.Event(enable_timing=False)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, tbo_evt.record()
kv_cache, remote_address) pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4)) kv_cache, remote_address, tbo_evt)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4), tbo_evt)
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i), tbo_evt)
else: else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
else:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, elif (self.pp_size == 2):
kv_cache, ip + ":" + str(port + self._rank - 4)) if (pp_rank == 0):
elif (self.pp_size == 8): self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
for i in range(8): kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i)) kv_cache, ip + ":" + str(port + self._rank + 4))
else: else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!") self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4))
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i))
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
def wait_for_save(self): def wait_for_save(self):
pass pass
......
...@@ -20,6 +20,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import ( ...@@ -20,6 +20,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool) TensorMemoryPool)
from vllm.utils import current_stream, get_ip from vllm.utils import current_stream, get_ip
from vllm import envs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -110,6 +111,7 @@ class P2pNcclEngine: ...@@ -110,6 +111,7 @@ class P2pNcclEngine:
self.recv_store_cv = threading.Condition() self.recv_store_cv = threading.Condition()
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
# self.send_stream = tbo_all_reduce_stream
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
mem_pool_size_gb = self.config.get_from_extra_config( mem_pool_size_gb = self.config.get_from_extra_config(
...@@ -194,6 +196,7 @@ class P2pNcclEngine: ...@@ -194,6 +196,7 @@ class P2pNcclEngine:
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
with self.recv_store_cv: with self.recv_store_cv:
...@@ -205,7 +208,7 @@ class P2pNcclEngine: ...@@ -205,7 +208,7 @@ class P2pNcclEngine:
return self._send_sync(tensor_id, tensor, remote_address) return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC": elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv: with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor]) self.send_queue.append([tensor_id, remote_address, tensor, tbo_evt])
self.send_queue_cv.notify() self.send_queue_cv.notify()
else: # GET else: # GET
with self.send_store_cv: with self.send_store_cv:
...@@ -389,9 +392,11 @@ class P2pNcclEngine: ...@@ -389,9 +392,11 @@ class P2pNcclEngine:
with self.send_queue_cv: with self.send_queue_cv:
while not self.send_queue: while not self.send_queue:
self.send_queue_cv.wait() self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft() tensor_id, remote_address, tensor, tbo_evt = self.send_queue.popleft()
if not self.send_queue: if not self.send_queue:
self.send_queue_cv.notify() self.send_queue_cv.notify()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_sync(tensor_id, tensor, remote_address) self._send_sync(tensor_id, tensor, remote_address)
def wait_for_sent(self): def wait_for_sent(self):
......
...@@ -170,6 +170,7 @@ if TYPE_CHECKING: ...@@ -170,6 +170,7 @@ if TYPE_CHECKING:
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_MORI_EP: bool = False VLLM_USE_MORI_EP: bool = False
VLLM_P2P_ASYNC: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1126,6 +1127,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1126,6 +1127,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MORI_EP": "VLLM_USE_MORI_EP":
lambda: (os.environ.get("VLLM_USE_MORI_EP", "True").lower() in lambda: (os.environ.get("VLLM_USE_MORI_EP", "True").lower() in
("true", "1")), ("true", "1")),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -75,8 +75,11 @@ class SiluAndMul(CustomOp): ...@@ -75,8 +75,11 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2 if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return F.silu(x[..., :d]) * x[..., d:] return self.forward_cuda(x)
else:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
......
...@@ -165,38 +165,39 @@ class RMSNorm(CustomOp): ...@@ -165,38 +165,39 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
x = x.to(torch.float32) return self.forward_cuda(x, residual)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else: else:
return x, residual orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else:
return x, residual
def forward_cuda( def forward_cuda(
self, self,
......
...@@ -17,10 +17,12 @@ logger = init_logger(__name__) ...@@ -17,10 +17,12 @@ logger = init_logger(__name__)
tbo_step_stream = None tbo_step_stream = None
all_reduce_stream = None all_reduce_stream = None
class TwoBatchOverlap():
STOP = object()
class TwoBatchOverlap:
def __init__(self): def __init__(self):
global tbo_step_stream global tbo_step_stream, all_reduce_stream
global all_reduce_stream
self.model_input_left_queue = queue.Queue() self.model_input_left_queue = queue.Queue()
self.model_input_right_queue = queue.Queue() self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue() self.states_left_queue = queue.Queue()
...@@ -29,12 +31,14 @@ class TwoBatchOverlap(): ...@@ -29,12 +31,14 @@ class TwoBatchOverlap():
self.right_thread = None self.right_thread = None
self.left_tid = 0 self.left_tid = 0
self.right_tid = 0 self.right_tid = 0
self._stop_evt = threading.Event()
self._threads_started = False
self.sem_left = threading.Semaphore(0) self.sem_left = threading.Semaphore(0)
self.sem_right = threading.Semaphore(0) self.sem_right = threading.Semaphore(0)
self.left_first = False self.left_first = False
self.tbo_running = False self.tbo_running = False
self.tbo_in_capture = False self.tbo_in_capture = False
if tbo_step_stream == None: if tbo_step_stream is None:
tbo_step_stream = torch.cuda.Stream() tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream() all_reduce_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False) self.step_event = torch.cuda.Event(enable_timing=False)
...@@ -44,60 +48,85 @@ class TwoBatchOverlap(): ...@@ -44,60 +48,85 @@ class TwoBatchOverlap():
self.event_right_t2c = torch.cuda.Event(enable_timing=False) self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self): def init_tbo_thread(self):
self.model_input_left_queue.empty() if self._threads_started:
self.model_input_right_queue.empty() return
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,)) if self.left_thread is None or not self.left_thread.is_alive():
self.left_thread.start() self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,)) args=(self.model_input_left_queue,), daemon=True)
self.right_thread.start() self.left_thread.start()
if get_tp_group().rank == 0: if self.right_thread is None or not self.right_thread.is_alive():
logger.info('tbo:two batch overlap start') self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
args=(self.model_input_right_queue,), daemon=True)
def finish_thread(self): self.right_thread.start()
self.left_thread.join() self._threads_started = True
self.left_thread = None
self.right_thread.join() def shutdown(self, timeout=5.0):
self.right_thread = None self._stop_evt.set()
try:
self.model_input_left_queue.put(STOP)
self.model_input_right_queue.put(STOP)
except Exception:
pass
if self.left_thread is not None:
self.left_thread.join(timeout=timeout)
self.left_thread = None
if self.right_thread is not None:
self.right_thread.join(timeout=timeout)
self.right_thread = None
@torch.inference_mode() @torch.inference_mode()
def thread_two_batch_overlap(self, queue): def thread_two_batch_overlap(self, q):
is_left_thread = False is_left_thread = False
tid = threading.get_ident() tid = threading.get_ident()
if queue == self.model_input_left_queue: if q is self.model_input_left_queue:
self.left_tid = tid self.left_tid = tid
is_left_thread = True is_left_thread = True
init_tbo_forward_context(True, self.left_tid) init_tbo_forward_context(True, self.left_tid)
else: else:
self.right_tid = tid self.right_tid = tid
init_tbo_forward_context(False, self.right_tid) init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(tbo_step_stream):
queue.get() while not self._stop_evt.is_set():
self.tbo_thread_synchronize(tid) item = q.get()
if is_left_thread: if item is STOP:
attn_metadata = self.attn_metadata_left break
num_input_tokens = self.num_input_tokens_left
input_ids = self.input_ids_left with torch.cuda.stream(tbo_step_stream):
positions = self.positions_left self.tbo_thread_synchronize(tid)
else:
attn_metadata = self.attn_metadata_right if is_left_thread:
num_input_tokens = self.num_input_tokens_right attn_metadata = self.attn_metadata_left
input_ids = self.input_ids_right num_input_tokens = self.num_input_tokens_left
positions = self.positions_right input_ids = self.input_ids_left
positions = self.positions_left
model_output = None else:
# Run the decoder. attn_metadata = self.attn_metadata_right
# Use persistent buffers for CUDA graphs. num_input_tokens = self.num_input_tokens_right
with set_forward_context(attn_metadata, input_ids = self.input_ids_right
self.model_runner.vllm_config, positions = self.positions_right
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp, # Select per-thread tensors (left/right) with backward-compatible fallback
skip_cuda_graphs=True): if is_left_thread:
model_output = self.model_runner.model( intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
input_ids=input_ids, else:
positions=positions, intermediate_tensors = getattr(self, 'intermediate_tensors_right', None)
intermediate_tensors=self.intermediate_tensors, if intermediate_tensors is None:
inputs_embeds=self.inputs_embeds, intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
)
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True,
):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=self.inputs_embeds,
)
if is_left_thread: if is_left_thread:
self.sem_right.release() self.sem_right.release()
self.states_left_queue.put(model_output) self.states_left_queue.put(model_output)
...@@ -117,18 +146,19 @@ class TwoBatchOverlap(): ...@@ -117,18 +146,19 @@ class TwoBatchOverlap():
return self.event_right_c2t, self.event_right_t2c return self.event_right_c2t, self.event_right_t2c
def set_model_input(self, def set_model_input(self,
model_runner, model_runner,
attn_metadata_left, attn_metadata_left,
attn_metadata_right, attn_metadata_right,
num_input_tokens_left, num_input_tokens_left,
num_input_tokens_right, num_input_tokens_right,
input_ids_left, input_ids_left,
input_ids_right, input_ids_right,
positions_left, positions_left,
positions_right, positions_right,
num_tokens_across_dp, num_tokens_across_dp,
intermediate_tensors, intermediate_tensors,
inputs_embeds): inputs_embeds,
):
self.model_runner = model_runner self.model_runner = model_runner
self.attn_metadata_left = attn_metadata_left self.attn_metadata_left = attn_metadata_left
self.attn_metadata_right = attn_metadata_right self.attn_metadata_right = attn_metadata_right
...@@ -139,26 +169,34 @@ class TwoBatchOverlap(): ...@@ -139,26 +169,34 @@ class TwoBatchOverlap():
self.positions_left = positions_left self.positions_left = positions_left
self.positions_right = positions_right self.positions_right = positions_right
self.num_tokens_across_dp = num_tokens_across_dp self.num_tokens_across_dp = num_tokens_across_dp
self.intermediate_tensors = intermediate_tensors
self.inputs_embeds = inputs_embeds self.inputs_embeds = inputs_embeds
if isinstance(intermediate_tensors, tuple):
self.intermediate_tensors_left, self.intermediate_tensors_right = intermediate_tensors
else:
self.intermediate_tensors_left = intermediate_tensors
self.intermediate_tensors_right = None
self.model_input_left_queue.put(None) self.model_input_left_queue.put(None)
self.model_input_right_queue.put(None) self.model_input_right_queue.put(None)
def get_model_output(self): def get_model_output(self):
states_left = self.states_left_queue.get() states_left = self.states_left_queue.get()
states_right = self.states_right_queue.get() states_right = self.states_right_queue.get()
return states_left, states_right return states_left, states_right
tbo_obj_v1 = None tbo_obj_v1 = None
def is_enable_tbo_v1(): def is_enable_tbo_v1():
global tbo_obj_v1 global tbo_obj_v1
return tbo_obj_v1 != None return tbo_obj_v1 is not None
def init_two_batch_overlap(): def init_two_batch_overlap():
global tbo_obj_v1 global tbo_obj_v1
if tbo_obj_v1 == None: if tbo_obj_v1 is None:
tbo_obj_v1 = TwoBatchOverlap() tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread() tbo_obj_v1.init_tbo_thread()
...@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache): ...@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj): def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running: if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
tid = threading.get_ident() tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid: if tid == tbo_obj_v1.left_tid:
event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
...@@ -185,7 +223,7 @@ def tbo_all_reduce_v1(obj): ...@@ -185,7 +223,7 @@ def tbo_all_reduce_v1(obj):
tbo_obj_v1.tbo_thread_synchronize(tid) tbo_obj_v1.tbo_thread_synchronize(tid)
tbo_step_stream.wait_event(event_t2c) tbo_step_stream.wait_event(event_t2c)
return output return output
return tensor_model_parallel_all_reduce(obj) return tensor_model_parallel_all_reduce(obj)
def merge_model_output(states_left, states_right): def merge_model_output(states_left, states_right):
if isinstance(states_left, IntermediateTensors): if isinstance(states_left, IntermediateTensors):
...@@ -199,45 +237,53 @@ def merge_model_output(states_left, states_right): ...@@ -199,45 +237,53 @@ def merge_model_output(states_left, states_right):
def tbo_model_executable_v1( def tbo_model_executable_v1(
model_runner, model_runner,
attn_metadata_left, attn_metadata_left,
attn_metadata_right, attn_metadata_right,
num_input_tokens_left, num_input_tokens_left,
num_input_tokens_right, num_input_tokens_right,
num_tokens_across_dp, num_tokens_across_dp,
input_ids, input_ids,
positions, positions,
intermediate_tensors, intermediate_tensors,
inputs_embeds inputs_embeds,
): ):
init_two_batch_overlap() init_two_batch_overlap()
tbo_obj_v1.tbo_running = True tbo_obj_v1.tbo_running = True
tbo_obj_v1.left_first = True tbo_obj_v1.left_first = True
tbo_obj_v1.step_event.record() tbo_obj_v1.step_event.record()
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
num_total_tokens = num_input_tokens_left + num_input_tokens_right
with torch.cuda.stream(tbo_step_stream): with torch.cuda.stream(tbo_step_stream):
tbo_step_stream.wait_event(tbo_obj_v1.step_event) tbo_step_stream.wait_event(tbo_obj_v1.step_event)
tokens_split = [num_input_tokens_left, num_input_tokens_right] tokens_split = [num_input_tokens_left, num_input_tokens_right]
input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0) input_ids_left, input_ids_right = torch.split(input_ids[:num_total_tokens], tokens_split, dim=0)
positions_left, positions_right = torch.split(positions, tokens_split, dim=0) positions_left, positions_right = torch.split(positions[:num_total_tokens], tokens_split, dim=0)
tbo_obj_v1.set_model_input(model_runner, tbo_obj_v1.set_model_input(model_runner,
attn_metadata_left, attn_metadata_left,
attn_metadata_right, attn_metadata_right,
num_input_tokens_left, num_input_tokens_left,
num_input_tokens_right, num_input_tokens_right,
input_ids_left, input_ids_left,
input_ids_right, input_ids_right,
positions_left, positions_left,
positions_right, positions_right,
num_tokens_across_dp, num_tokens_across_dp,
intermediate_tensors, intermediate_tensors,
inputs_embeds) inputs_embeds,
)
model_output_left, model_output_right = tbo_obj_v1.get_model_output() model_output_left, model_output_right = tbo_obj_v1.get_model_output()
hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right) hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
tbo_obj_v1.tbo_running = False tbo_obj_v1.tbo_running = False
tbo_obj_v1.step_event.record() tbo_obj_v1.step_event.record()
tbo_obj_v1.finish_thread()
current_stream.wait_event(tbo_obj_v1.step_event) current_stream.wait_event(tbo_obj_v1.step_event)
return hidden_or_intermediate_states return hidden_or_intermediate_states
\ No newline at end of file
def finalize_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 is not None:
try:
tbo_obj_v1.shutdown()
finally:
tbo_obj_v1 = None
\ No newline at end of file
...@@ -69,7 +69,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch ...@@ -69,7 +69,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.profiler.prof import profile
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders) sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
...@@ -107,6 +107,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -107,6 +107,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
if envs.VLLM_P2P_ASYNC:
self.p2p_event = torch.cuda.Event(enable_timing=False)
self.p2p_stream = torch.cuda.Stream()
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes( set_cpu_offload_max_bytes(
...@@ -1298,6 +1301,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1298,6 +1301,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
# profile.StartTracer()
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group(): if not has_kv_transfer_group():
...@@ -1377,13 +1381,40 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1377,13 +1381,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we # If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
model_output, finished_sending, finished_recving = \ model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens, tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions, num_tokens_across_dp, input_ids, positions,
inputs_embeds, scheduler_output, intermediate_tensors, inputs_embeds, scheduler_output, intermediate_tensors,
skip_cuda_graphs) skip_cuda_graphs)
elif envs.VLLM_P2P_ASYNC:
self.p2p_event.record()
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.p2p_stream):
self.p2p_stream.wait_event(self.p2p_event)
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
):
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
self.p2p_event.record()
current_stream.wait_event(self.p2p_event)
else: else:
# Run the model. # Run the model.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment