Unverified Commit 17a57fd8 authored by Yuan Luo, committed by GitHub

[Perf] Optimize multimodal mm_inputs process in scheduler (#11910)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 32438eba
@@ -29,6 +29,7 @@ from typing import Deque, Dict, List, Optional, Tuple, Union
 import psutil
 import setproctitle
 import torch
+import torch.distributed
 import zmq
 from torch.cuda import Stream as CudaStream
 from torch.cuda import StreamContext as CudaStreamContext
@@ -378,6 +379,17 @@ class Scheduler(
         self.pp_group = get_pp_group()
         self.world_group = get_world_group()

+        # With DP attention enabled, the entry rank is attn_tp_rank == 0;
+        # otherwise the entry rank is TP-group local rank 0.
+        # For #11910, use the CPU communication group to broadcast VLM Python
+        # objects, avoiding any coupling with CUDA streams/devices.
+        if self.server_args.enable_dp_attention:
+            self.cpu_group = self.attn_tp_cpu_group
+            self.is_entry_rank = self.attn_tp_rank == 0
+        else:
+            self.cpu_group = self.tp_cpu_group
+            self.is_entry_rank = self.tp_group.rank == 0
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         set_random_seed(self.random_seed)
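
A note on the CPU group choice: `torch.distributed.broadcast_object_list` pickles Python objects, and on a CUDA/NCCL group the pickled bytes would be staged through device tensors on the current CUDA stream; a CPU (gloo) group keeps the whole transfer on the host. A minimal sketch of how such a group can be created alongside the default device group (the helper name `init_cpu_group` and the `ranks` argument are illustrative, not sglang's actual API):

```python
import torch.distributed as dist

def init_cpu_group(ranks):
    # A gloo subgroup next to the default (e.g. NCCL) group lets
    # Python-object collectives run entirely on the host, untouched
    # by CUDA streams or device placement.
    return dist.new_group(ranks=ranks, backend="gloo")
```
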
@@ -1133,6 +1145,70 @@ class Scheduler(
             self.max_req_len - len(req.origin_input_ids) - 1,
         )

+    def _process_and_broadcast_mm_inputs(
+        self,
+        raw_mm_inputs: Optional[dict],
+    ):
+        """Materialize MultimodalInputs once on the entry rank and broadcast it to the others.
+
+        Entry rank:
+            - constructs MultimodalInputs.from_dict(raw_mm_inputs) once
+            - broadcasts the result to the other ranks in self.cpu_group (if world_size > 1)
+        Non-entry ranks:
+            - receive the object via broadcast (if world_size > 1)
+            - otherwise (single rank / no group) fall back to a local from_dict
+
+        Returns:
+            MultimodalInputs | None
+        """
+        if raw_mm_inputs is None:
+            return None
+
+        group_world_size = 1
+        try:
+            if (
+                torch.distributed.is_available()
+                and torch.distributed.is_initialized()
+                and self.cpu_group is not None
+            ):
+                group_world_size = torch.distributed.get_world_size(
+                    group=self.cpu_group
+                )
+        except Exception as e:
+            logger.warning(
+                f"Failed to get world size in mm_inputs handling ({e}); falling back to 1."
+            )
+
+        # With tp size > 1, every Scheduler TP rank would otherwise run the same
+        # duplicated CPU computation, occupying main-thread CPU cycles. This work
+        # only needs to run on the entry rank and be broadcast to the other TP
+        # ranks. Since the Scheduler is single-threaded, any large CPU cost delays
+        # the handling of other messages; for example, CPU utilization at 99.9%
+        # can significantly increase CUDA kernel launch times.
+        if self.is_entry_rank:
+            # Only the entry rank materializes the object from the dict.
+            image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+            # Broadcast to the other TP ranks (src=0 within the group).
+            if group_world_size > 1:
+                obj_list = [image_inputs]
+                torch.distributed.broadcast_object_list(
+                    obj_list, src=0, group=self.cpu_group
+                )
+                image_inputs = obj_list[0]
+        else:
+            # Non-entry ranks: receive if the group size > 1; otherwise materialize locally.
+            if group_world_size > 1:
+                obj_list = [None]
+                torch.distributed.broadcast_object_list(
+                    obj_list, src=0, group=self.cpu_group
+                )
+                image_inputs = obj_list[0]
+            else:
+                image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+
+        return image_inputs
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
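
For reference, a minimal self-contained sketch of the broadcast-once pattern the helper implements, runnable with e.g. `torchrun --nproc_per_node=2 sketch.py`; here `expensive_materialize` is a hypothetical stand-in for `MultimodalInputs.from_dict`:

```python
import torch.distributed as dist

def expensive_materialize(raw: dict) -> dict:
    # Placeholder for the costly CPU-bound construction done on the entry rank.
    return {k: v * 2 for k, v in raw.items()}

def main():
    dist.init_process_group(backend="gloo")  # CPU-only group; torchrun sets env vars
    rank = dist.get_rank()

    if rank == 0:
        raw = {"pixel_values": 21}
        obj_list = [expensive_materialize(raw)]  # materialize once on the entry rank
    else:
        obj_list = [None]

    # Pickles on the source rank, unpickles on the others; no CUDA involvement.
    dist.broadcast_object_list(obj_list, src=0)
    print(f"rank {rank} got {obj_list[0]}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```
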
@@ -1214,7 +1290,9 @@
         # Handle multimodal inputs
         if recv_req.mm_inputs is not None:
-            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
+            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.mm_inputs)
+
+            # The following steps are already fast; execute them locally on each rank.
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -1444,7 +1522,7 @@
         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
+            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
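
The `pad_input_ids_func` step left untouched by this change is model-specific; conceptually, it replaces each single image token with as many placeholder tokens as the image's embedding will occupy, so the embeddings can later be scattered into those positions. A hedged sketch of that idea (all names and token ids below are illustrative, not sglang's actual implementation):

```python
from typing import List

def pad_input_ids(input_ids: List[int], image_token_id: int,
                  num_embedding_tokens: int, pad_token_id: int) -> List[int]:
    # Expand each single image token into num_embedding_tokens dummy tokens,
    # reserving the positions where image embeddings will be written.
    out: List[int] = []
    for tok in input_ids:
        if tok == image_token_id:
            out.extend([pad_token_id] * num_embedding_tokens)
        else:
            out.append(tok)
    return out

# e.g. pad_input_ids([1, 32000, 2], image_token_id=32000,
#                    num_embedding_tokens=4, pad_token_id=0)
# -> [1, 0, 0, 0, 0, 2]
```
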