Commit 1a135a9d, authored by maxiao1 and committed by lizhigong
Browse files

Adapt token-by-token splitting to PD separation and allow P2P to be used asynchronously

parent fc5bfc66
...@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context ...@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_get_done_event
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -36,6 +36,7 @@ class ReqMeta: ...@@ -36,6 +36,7 @@ class ReqMeta:
token_ids: torch.Tensor token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids # Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
slot_mapping_device: torch.Tensor = None
@staticmethod @staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
...@@ -274,8 +275,6 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -274,8 +275,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx) Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise. if MLA is not used, and (num_pages, page_size, xxx) otherwise.
""" """
# if envs.VLLM_ENABLE_TBO:
# slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if isinstance(attn_metadata, MLACommonMetadata): if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1] num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...@@ -287,41 +286,40 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -287,41 +286,40 @@ class P2pNcclConnector(KVConnectorBase_V1):
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata) assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
send_stream = self.p2p_nccl_engine.send_stream
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
slot_mapping = request.slot_mapping
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) if request.slot_mapping_device is None:
# tbo_evt = torch.cuda.Event(enable_timing=False) request.slot_mapping_device = \
# tbo_evt.record() request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True)
# with torch.cuda.stream(send_stream): slot_mapping = request.slot_mapping_device
# send_stream.wait_event(tbo_evt) # 等 TBO all_reduce_stream 完成本轮 kv_cache = extract_kv_from_layer(kv_layer, slot_mapping)
# kv_cache.record_stream(send_stream) tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record()
pp_rank = (self.parallel_config.rank // pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \ self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1): if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_address, tbo_evt)
elif (self.pp_size == 2): elif (self.pp_size == 2):
if (pp_rank == 0): if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4)) kv_cache, ip + ":" + str(port + self._rank + 4), tbo_evt)
else: else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4)) kv_cache, ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8): elif (self.pp_size == 8):
for i in range(8): for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i)) kv_cache, ip + ":" + str(port + i), tbo_evt)
else: else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!") print("Error: only suppprt pp1 pp2 pp8!!!!!!")
else: else:
......
...@@ -20,7 +20,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import ( ...@@ -20,7 +20,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool) TensorMemoryPool)
from vllm.utils import current_stream, get_ip from vllm.utils import current_stream, get_ip
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import all_reduce_stream as tbo_all_reduce_stream from vllm import envs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -196,6 +196,7 @@ class P2pNcclEngine: ...@@ -196,6 +196,7 @@ class P2pNcclEngine:
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
with self.recv_store_cv: with self.recv_store_cv:
...@@ -207,7 +208,7 @@ class P2pNcclEngine: ...@@ -207,7 +208,7 @@ class P2pNcclEngine:
return self._send_sync(tensor_id, tensor, remote_address) return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC": elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv: with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor]) self.send_queue.append([tensor_id, remote_address, tensor, tbo_evt])
self.send_queue_cv.notify() self.send_queue_cv.notify()
else: # GET else: # GET
with self.send_store_cv: with self.send_store_cv:
...@@ -391,9 +392,11 @@ class P2pNcclEngine: ...@@ -391,9 +392,11 @@ class P2pNcclEngine:
with self.send_queue_cv: with self.send_queue_cv:
while not self.send_queue: while not self.send_queue:
self.send_queue_cv.wait() self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft() tensor_id, remote_address, tensor, tbo_evt = self.send_queue.popleft()
if not self.send_queue: if not self.send_queue:
self.send_queue_cv.notify() self.send_queue_cv.notify()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_sync(tensor_id, tensor, remote_address) self._send_sync(tensor_id, tensor, remote_address)
def wait_for_sent(self): def wait_for_sent(self):
......
...@@ -169,6 +169,7 @@ if TYPE_CHECKING: ...@@ -169,6 +169,7 @@ if TYPE_CHECKING:
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1114,6 +1115,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1114,6 +1115,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -323,11 +323,11 @@ def tbo_split_and_execute_model( ...@@ -323,11 +323,11 @@ def tbo_split_and_execute_model(
) )
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector === # === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
# 真实 token # real token nums
num_tokens_left = int(input_split.scheduler_output_left.total_num_scheduled_tokens) num_tokens_left = int(input_split.scheduler_output_left.total_num_scheduled_tokens)
num_tokens_right = int(input_split.scheduler_output_right.total_num_scheduled_tokens) num_tokens_right = int(input_split.scheduler_output_right.total_num_scheduled_tokens)
# 按左右半批切成两份 # split intermediate tensors
def _split_intermediate_tensors(it, l, r): def _split_intermediate_tensors(it, l, r):
if it is None: return None, None if it is None: return None, None
left_tensor_map, right_tensor_map = {}, {} left_tensor_map, right_tensor_map = {}, {}
......
...@@ -17,7 +17,7 @@ logger = init_logger(__name__) ...@@ -17,7 +17,7 @@ logger = init_logger(__name__)
tbo_step_stream = None tbo_step_stream = None
all_reduce_stream = None all_reduce_stream = None
PERSIST_THREADS = os.getenv('VLLM_TBO_PERSIST_THREADS', '1') not in ('0','false','False','no','NO','')
STOP = object() STOP = object()
class TwoBatchOverlap: class TwoBatchOverlap:
...@@ -48,7 +48,7 @@ class TwoBatchOverlap: ...@@ -48,7 +48,7 @@ class TwoBatchOverlap:
self.event_right_t2c = torch.cuda.Event(enable_timing=False) self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self): def init_tbo_thread(self):
if self._threads_started and PERSIST_THREADS: if self._threads_started:
return return
if self.left_thread is None or not self.left_thread.is_alive(): if self.left_thread is None or not self.left_thread.is_alive():
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
...@@ -220,7 +220,6 @@ def tbo_all_reduce_v1(obj): ...@@ -220,7 +220,6 @@ def tbo_all_reduce_v1(obj):
all_reduce_stream.wait_event(event_c2t) all_reduce_stream.wait_event(event_c2t)
output = tensor_model_parallel_all_reduce(obj) output = tensor_model_parallel_all_reduce(obj)
event_t2c.record() event_t2c.record()
#tbo_mark_allreduce_done()
tbo_obj_v1.tbo_thread_synchronize(tid) tbo_obj_v1.tbo_thread_synchronize(tid)
tbo_step_stream.wait_event(event_t2c) tbo_step_stream.wait_event(event_t2c)
return output return output
...@@ -281,18 +280,6 @@ def tbo_model_executable_v1( ...@@ -281,18 +280,6 @@ def tbo_model_executable_v1(
return hidden_or_intermediate_states return hidden_or_intermediate_states
_tbo_done_event = torch.cuda.Event(enable_timing=False)
def tbo_mark_allreduce_done():
"""Record completion of all_reduce_stream for external synchronization."""
global all_reduce_stream, _tbo_done_event
_tbo_done_event.record(all_reduce_stream)
def tbo_get_done_event():
"""Return the event recorded by all_reduce_stream."""
return _tbo_done_event
def finalize_two_batch_overlap(): def finalize_two_batch_overlap():
global tbo_obj_v1 global tbo_obj_v1
if tbo_obj_v1 is not None: if tbo_obj_v1 is not None:
......
...@@ -1295,7 +1295,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1295,7 +1295,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
profile.StartTracer()
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group(): if not has_kv_transfer_group():
...@@ -1574,7 +1574,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1574,7 +1574,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
self.eplb_step() self.eplb_step()
print('###valid_sampled_token_ids', valid_sampled_token_ids)
return ModelRunnerOutput( return ModelRunnerOutput(
req_ids=self.input_batch.req_ids, req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index, req_id_to_index=self.input_batch.req_id_to_index,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment