Commit cf13152f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of ssh://10.16.6.30:10022/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents ca0fea07 99863602
......@@ -36,6 +36,7 @@ class ReqMeta:
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
slot_mapping_device: torch.Tensor = None
@staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
......@@ -274,8 +275,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if envs.VLLM_ENABLE_TBO:
slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
......@@ -286,6 +285,44 @@ class P2pNcclConnector(KVConnectorBase_V1):
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
slot_mapping = request.slot_mapping
if request.slot_mapping_device is None:
request.slot_mapping_device = \
request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True)
slot_mapping = request.slot_mapping_device
kv_cache = extract_kv_from_layer(kv_layer, slot_mapping)
tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record()
pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4), tbo_evt)
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i), tbo_evt)
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
else:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
......
......@@ -20,6 +20,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool)
from vllm.utils import current_stream, get_ip
from vllm import envs
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
......@@ -110,6 +111,7 @@ class P2pNcclEngine:
self.recv_store_cv = threading.Condition()
self.send_stream = torch.cuda.Stream()
# self.send_stream = tbo_all_reduce_stream
self.recv_stream = torch.cuda.Stream()
mem_pool_size_gb = self.config.get_from_extra_config(
......@@ -194,6 +196,7 @@ class P2pNcclEngine:
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
......@@ -205,7 +208,7 @@ class P2pNcclEngine:
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor])
self.send_queue.append([tensor_id, remote_address, tensor, tbo_evt])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
......@@ -389,9 +392,11 @@ class P2pNcclEngine:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft()
tensor_id, remote_address, tensor, tbo_evt = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_sync(tensor_id, tensor, remote_address)
def wait_for_sent(self):
......
......@@ -170,6 +170,7 @@ if TYPE_CHECKING:
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_ALLTOALL_EP: bool = False
VLLM_P2P_ASYNC: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1119,6 +1120,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_ALLTOALL_EP":
lambda: (os.environ.get("VLLM_USE_ALLTOALL_EP", "True").lower() in
("true", "1")),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
}
# --8<-- [end:env-vars-definition]
......
......@@ -75,6 +75,9 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x)
else:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
......
......@@ -165,7 +165,10 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x, residual)
else:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
......@@ -184,11 +187,9 @@ class RMSNorm(CustomOp):
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
......
......@@ -17,10 +17,12 @@ logger = init_logger(__name__)
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
STOP = object()
class TwoBatchOverlap:
def __init__(self):
global tbo_step_stream
global all_reduce_stream
global tbo_step_stream, all_reduce_stream
self.model_input_left_queue = queue.Queue()
self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue()
......@@ -29,12 +31,14 @@ class TwoBatchOverlap():
self.right_thread = None
self.left_tid = 0
self.right_tid = 0
self._stop_evt = threading.Event()
self._threads_started = False
self.sem_left = threading.Semaphore(0)
self.sem_right = threading.Semaphore(0)
self.left_first = False
self.tbo_running = False
self.tbo_in_capture = False
if tbo_step_stream == None:
if tbo_step_stream is None:
tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False)
......@@ -44,35 +48,52 @@ class TwoBatchOverlap():
self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self):
self.model_input_left_queue.empty()
self.model_input_right_queue.empty()
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
if self._threads_started:
return
if self.left_thread is None or not self.left_thread.is_alive():
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
args=(self.model_input_left_queue,), daemon=True)
self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
if self.right_thread is None or not self.right_thread.is_alive():
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
args=(self.model_input_right_queue,), daemon=True)
self.right_thread.start()
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
self._threads_started = True
def finish_thread(self):
self.left_thread.join()
def shutdown(self, timeout=5.0):
self._stop_evt.set()
try:
self.model_input_left_queue.put(STOP)
self.model_input_right_queue.put(STOP)
except Exception:
pass
if self.left_thread is not None:
self.left_thread.join(timeout=timeout)
self.left_thread = None
self.right_thread.join()
if self.right_thread is not None:
self.right_thread.join(timeout=timeout)
self.right_thread = None
@torch.inference_mode()
def thread_two_batch_overlap(self, queue):
def thread_two_batch_overlap(self, q):
is_left_thread = False
tid = threading.get_ident()
if queue == self.model_input_left_queue:
if q is self.model_input_left_queue:
self.left_tid = tid
is_left_thread = True
init_tbo_forward_context(True, self.left_tid)
else:
self.right_tid = tid
init_tbo_forward_context(False, self.right_tid)
while not self._stop_evt.is_set():
item = q.get()
if item is STOP:
break
with torch.cuda.stream(tbo_step_stream):
queue.get()
self.tbo_thread_synchronize(tid)
if is_left_thread:
attn_metadata = self.attn_metadata_left
num_input_tokens = self.num_input_tokens_left
......@@ -84,20 +105,28 @@ class TwoBatchOverlap():
input_ids = self.input_ids_right
positions = self.positions_right
model_output = None
# Run the decoder.
# Use persistent buffers for CUDA graphs.
# Select per-thread tensors (left/right) with backward-compatible fallback
if is_left_thread:
intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
else:
intermediate_tensors = getattr(self, 'intermediate_tensors_right', None)
if intermediate_tensors is None:
intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True):
skip_cuda_graphs=True,
):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=self.intermediate_tensors,
intermediate_tensors=intermediate_tensors,
inputs_embeds=self.inputs_embeds,
)
if is_left_thread:
self.sem_right.release()
self.states_left_queue.put(model_output)
......@@ -128,7 +157,8 @@ class TwoBatchOverlap():
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds):
inputs_embeds,
):
self.model_runner = model_runner
self.attn_metadata_left = attn_metadata_left
self.attn_metadata_right = attn_metadata_right
......@@ -139,9 +169,14 @@ class TwoBatchOverlap():
self.positions_left = positions_left
self.positions_right = positions_right
self.num_tokens_across_dp = num_tokens_across_dp
self.intermediate_tensors = intermediate_tensors
self.inputs_embeds = inputs_embeds
if isinstance(intermediate_tensors, tuple):
self.intermediate_tensors_left, self.intermediate_tensors_right = intermediate_tensors
else:
self.intermediate_tensors_left = intermediate_tensors
self.intermediate_tensors_right = None
self.model_input_left_queue.put(None)
self.model_input_right_queue.put(None)
......@@ -150,15 +185,18 @@ class TwoBatchOverlap():
states_right = self.states_right_queue.get()
return states_left, states_right
tbo_obj_v1 = None
def is_enable_tbo_v1():
global tbo_obj_v1
return tbo_obj_v1 != None
return tbo_obj_v1 is not None
def init_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 == None:
if tbo_obj_v1 is None:
tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread()
......@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid:
event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
......@@ -207,19 +245,19 @@ def tbo_model_executable_v1(
input_ids,
positions,
intermediate_tensors,
inputs_embeds
inputs_embeds,
):
init_two_batch_overlap()
tbo_obj_v1.tbo_running = True
tbo_obj_v1.left_first = True
tbo_obj_v1.step_event.record()
current_stream = torch.cuda.current_stream()
num_total_tokens = num_input_tokens_left + num_input_tokens_right
with torch.cuda.stream(tbo_step_stream):
tbo_step_stream.wait_event(tbo_obj_v1.step_event)
tokens_split = [num_input_tokens_left, num_input_tokens_right]
input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
input_ids_left, input_ids_right = torch.split(input_ids[:num_total_tokens], tokens_split, dim=0)
positions_left, positions_right = torch.split(positions[:num_total_tokens], tokens_split, dim=0)
tbo_obj_v1.set_model_input(model_runner,
attn_metadata_left,
attn_metadata_right,
......@@ -231,13 +269,21 @@ def tbo_model_executable_v1(
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds)
inputs_embeds,
)
model_output_left, model_output_right = tbo_obj_v1.get_model_output()
hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
tbo_obj_v1.tbo_running = False
tbo_obj_v1.step_event.record()
tbo_obj_v1.finish_thread()
current_stream.wait_event(tbo_obj_v1.step_event)
return hidden_or_intermediate_states
def finalize_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 is not None:
try:
tbo_obj_v1.shutdown()
finally:
tbo_obj_v1 = None
\ No newline at end of file
......@@ -70,7 +70,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.profiler.prof import profile
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
......@@ -108,6 +108,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
if envs.VLLM_P2P_ASYNC:
self.p2p_event = torch.cuda.Event(enable_timing=False)
self.p2p_stream = torch.cuda.Stream()
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(
......@@ -1299,6 +1302,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
# profile.StartTracer()
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
......@@ -1378,13 +1382,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions,
inputs_embeds, scheduler_output, intermediate_tensors,
skip_cuda_graphs)
elif envs.VLLM_P2P_ASYNC:
self.p2p_event.record()
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.p2p_stream):
self.p2p_stream.wait_event(self.p2p_event)
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
):
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
self.p2p_event.record()
current_stream.wait_event(self.p2p_event)
else:
# Run the model.
# Use persistent buffers for CUDA graphs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment