Commit 1a135a9d authored by maxiao1's avatar maxiao1 Committed by lizhigong
Browse files

token split by token adapt to pd separation & p2p can be used async

parent fc5bfc66
......@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_get_done_event
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
......@@ -36,6 +36,7 @@ class ReqMeta:
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
slot_mapping_device: torch.Tensor = None
@staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
......@@ -274,8 +275,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
# if envs.VLLM_ENABLE_TBO:
# slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
......@@ -287,41 +286,40 @@ class P2pNcclConnector(KVConnectorBase_V1):
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
if envs.VLLM_ENABLE_TBO:
send_stream = self.p2p_nccl_engine.send_stream
if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
# tbo_evt = torch.cuda.Event(enable_timing=False)
# tbo_evt.record()
# with torch.cuda.stream(send_stream):
# send_stream.wait_event(tbo_evt) # 等 TBO all_reduce_stream 完成本轮
# kv_cache.record_stream(send_stream)
slot_mapping = request.slot_mapping
if request.slot_mapping_device is None:
request.slot_mapping_device = \
request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True)
slot_mapping = request.slot_mapping_device
kv_cache = extract_kv_from_layer(kv_layer, slot_mapping)
tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record()
pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
kv_cache, remote_address, tbo_evt)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4))
kv_cache, ip + ":" + str(port + self._rank + 4), tbo_evt)
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4))
kv_cache, ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i))
kv_cache, ip + ":" + str(port + i), tbo_evt)
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
else:
......
......@@ -20,7 +20,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool)
from vllm.utils import current_stream, get_ip
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import all_reduce_stream as tbo_all_reduce_stream
from vllm import envs
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
......@@ -196,6 +196,7 @@ class P2pNcclEngine:
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
......@@ -207,7 +208,7 @@ class P2pNcclEngine:
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor])
self.send_queue.append([tensor_id, remote_address, tensor, tbo_evt])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
......@@ -391,9 +392,11 @@ class P2pNcclEngine:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft()
tensor_id, remote_address, tensor, tbo_evt = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_sync(tensor_id, tensor, remote_address)
def wait_for_sent(self):
......
......@@ -169,6 +169,7 @@ if TYPE_CHECKING:
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1114,6 +1115,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
}
# --8<-- [end:env-vars-definition]
......
......@@ -323,11 +323,11 @@ def tbo_split_and_execute_model(
)
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
# 真实 token
# real token nums
num_tokens_left = int(input_split.scheduler_output_left.total_num_scheduled_tokens)
num_tokens_right = int(input_split.scheduler_output_right.total_num_scheduled_tokens)
# 按左右半批切成两份
# split intermediate tensors
def _split_intermediate_tensors(it, l, r):
if it is None: return None, None
left_tensor_map, right_tensor_map = {}, {}
......
......@@ -17,7 +17,7 @@ logger = init_logger(__name__)
tbo_step_stream = None
all_reduce_stream = None
PERSIST_THREADS = os.getenv('VLLM_TBO_PERSIST_THREADS', '1') not in ('0','false','False','no','NO','')
STOP = object()
class TwoBatchOverlap:
......@@ -48,7 +48,7 @@ class TwoBatchOverlap:
self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self):
if self._threads_started and PERSIST_THREADS:
if self._threads_started:
return
if self.left_thread is None or not self.left_thread.is_alive():
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
......@@ -220,7 +220,6 @@ def tbo_all_reduce_v1(obj):
all_reduce_stream.wait_event(event_c2t)
output = tensor_model_parallel_all_reduce(obj)
event_t2c.record()
#tbo_mark_allreduce_done()
tbo_obj_v1.tbo_thread_synchronize(tid)
tbo_step_stream.wait_event(event_t2c)
return output
......@@ -281,18 +280,6 @@ def tbo_model_executable_v1(
return hidden_or_intermediate_states
_tbo_done_event = torch.cuda.Event(enable_timing=False)
def tbo_mark_allreduce_done():
"""Record completion of all_reduce_stream for external synchronization."""
global all_reduce_stream, _tbo_done_event
_tbo_done_event.record(all_reduce_stream)
def tbo_get_done_event():
"""Return the event recorded by all_reduce_stream."""
return _tbo_done_event
def finalize_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 is not None:
......
......@@ -1295,7 +1295,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
profile.StartTracer()
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
......@@ -1574,7 +1574,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
get_kv_transfer_group().clear_connector_metadata()
self.eplb_step()
print('###valid_sampled_token_ids', valid_sampled_token_ids)
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment