Unverified commit 80955ef4, authored by Kris Hung and committed by GitHub
Browse files

perf: Keep embeddings on GPU Embedding Sender in EPD pipeline + minor fixes (#6535)

parent 5734f5c4
...@@ -361,13 +361,11 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender): ...@@ -361,13 +361,11 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender):
Returns: Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed. A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
""" """
# If not staging embedding and embedding is on CPU, we explicitly copy if stage_embeddings:
# the tensor as torch.Tensor.cpu() will return original tensor if it's already on CPU transfer_buf = embeddings
if not stage_embeddings and not embeddings.is_cuda:
embeddings_cpu = embeddings.clone().detach()
else: else:
embeddings_cpu = embeddings.cpu() transfer_buf = embeddings.clone().detach()
descriptor = nixl_connect.Descriptor(embeddings_cpu) descriptor = nixl_connect.Descriptor(transfer_buf)
readable_op = await self.connector.create_readable(descriptor) readable_op = await self.connector.create_readable(descriptor)
request = TransferRequest( request = TransferRequest(
......
...@@ -42,7 +42,7 @@ ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1)) ...@@ -42,7 +42,7 @@ ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1))
class EmbeddingItem: class EmbeddingItem:
key: str key: str
image_grid_thw: list image_grid_thw: list
embeddings_cpu: torch.Tensor embeddings: torch.Tensor
class EncodeWorkerHandler: class EncodeWorkerHandler:
...@@ -140,11 +140,11 @@ class EncodeWorkerHandler: ...@@ -140,11 +140,11 @@ class EncodeWorkerHandler:
if self.embedding_cache is not None and self.embedding_cache.has_key( if self.embedding_cache is not None and self.embedding_cache.has_key(
embedding_key embedding_key
): ):
(image_grid_thw, embeddings_cpu) = self.embedding_cache.get( (image_grid_thw, embeddings) = self.embedding_cache.get(
embedding_key embedding_key
) )
embedding_lists[idx] = EmbeddingItem( embedding_lists[idx] = EmbeddingItem(
embedding_key, image_grid_thw, embeddings_cpu embedding_key, image_grid_thw, embeddings
) )
# compute # compute
else: else:
...@@ -200,7 +200,7 @@ class EncodeWorkerHandler: ...@@ -200,7 +200,7 @@ class EncodeWorkerHandler:
// merge_size // merge_size
// merge_size // merge_size
).tolist() ).tolist()
splitted_embeddings = embeddings.cpu().squeeze(0).split(sizes) splitted_embeddings = embeddings.squeeze(0).split(sizes)
logger.debug( logger.debug(
f"Splitted embeddings lengths: {[e.shape for e in splitted_embeddings]}" f"Splitted embeddings lengths: {[e.shape for e in splitted_embeddings]}"
) )
...@@ -209,7 +209,7 @@ class EncodeWorkerHandler: ...@@ -209,7 +209,7 @@ class EncodeWorkerHandler:
# embeddings already has batch dimension for images, so we can directly # embeddings already has batch dimension for images, so we can directly
# split by batch dimension # split by batch dimension
logger.debug(f"image embedding shape: {embeddings.shape}") logger.debug(f"image embedding shape: {embeddings.shape}")
splitted_embeddings = embeddings.cpu() splitted_embeddings = embeddings
image_grid_thw = ( image_grid_thw = (
image_embeds["image_grid_thw"].tolist() image_embeds["image_grid_thw"].tolist()
...@@ -230,7 +230,7 @@ class EncodeWorkerHandler: ...@@ -230,7 +230,7 @@ class EncodeWorkerHandler:
embedding_lists[list_idx].key, embedding_lists[list_idx].key,
( (
embedding_lists[list_idx].image_grid_thw, embedding_lists[list_idx].image_grid_thw,
embedding_lists[list_idx].embeddings_cpu, embedding_lists[list_idx].embeddings,
), ),
) )
...@@ -240,7 +240,7 @@ class EncodeWorkerHandler: ...@@ -240,7 +240,7 @@ class EncodeWorkerHandler:
send_tasks = [ send_tasks = [
asyncio.create_task( asyncio.create_task(
self.embedding_sender.send_embeddings( self.embedding_sender.send_embeddings(
embedding_item.embeddings_cpu, stage_embeddings=True embedding_item.embeddings, stage_embeddings=True
) )
) )
for embedding_item in embedding_lists for embedding_item in embedding_lists
...@@ -252,7 +252,7 @@ class EncodeWorkerHandler: ...@@ -252,7 +252,7 @@ class EncodeWorkerHandler:
for idx, item in enumerate(zip(embedding_lists, transfer_requests)): for idx, item in enumerate(zip(embedding_lists, transfer_requests)):
embedding_item, transfer_request = item embedding_item, transfer_request = item
logger.debug( logger.debug(
f"{embedding_item.embeddings_cpu.shape} prepared for transfer." f"{embedding_item.embeddings.shape} prepared for transfer."
) )
# Update request for transfer metadata # Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None request.multimodal_inputs[idx].multimodal_input.image_url = None
...@@ -260,13 +260,13 @@ class EncodeWorkerHandler: ...@@ -260,13 +260,13 @@ class EncodeWorkerHandler:
idx idx
].image_grid_thw = embedding_item.image_grid_thw ].image_grid_thw = embedding_item.image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = tuple( request.multimodal_inputs[idx].embeddings_shape = tuple(
embedding_item.embeddings_cpu.shape embedding_item.embeddings.shape
) )
request.multimodal_inputs[idx].serialized_request = transfer_request[0] request.multimodal_inputs[idx].serialized_request = transfer_request[0]
# Keep a reference of the embedding_cpu and only drop reference when the transfer is done # Keep a reference of the embedding and only drop reference when the transfer is done
self.send_complete_queue.put_nowait( self.send_complete_queue.put_nowait(
(transfer_request[1], embedding_item.embeddings_cpu) (transfer_request[1], embedding_item.embeddings)
) )
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
......
...@@ -129,6 +129,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -129,6 +129,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
for item in mm_data.get(IMAGE_URL_KEY, []): for item in mm_data.get(IMAGE_URL_KEY, []):
if isinstance(item, dict) and "Url" in item: if isinstance(item, dict) and "Url" in item:
image_urls.append(item["Url"]) image_urls.append(item["Url"])
elif isinstance(item, dict) and "Decoded" in item:
image_urls.append(item["Decoded"])
sampling_params = build_sampling_params( sampling_params = build_sampling_params(
raw_request, self.default_sampling_params raw_request, self.default_sampling_params
......
...@@ -33,6 +33,10 @@ VLLM_ENCODER = int(os.getenv("VLLM_ENCODER", 1)) ...@@ -33,6 +33,10 @@ VLLM_ENCODER = int(os.getenv("VLLM_ENCODER", 1))
class SupportedModels: class SupportedModels:
"""Supported multimodal model identifiers""" """Supported multimodal model identifiers"""
# TODO: Replace this explicit model list with dynamic detection using
# HF config `architectures` field or vLLM's model registry, so any
# vLLM-supported VLM works without maintaining entries here.
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf" LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct" QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct"
QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
...@@ -42,6 +46,8 @@ class SupportedModels: ...@@ -42,6 +46,8 @@ class SupportedModels:
QWEN_3_VL_30B_A3B = "Qwen/Qwen3-VL-30B-A3B-Instruct" QWEN_3_VL_30B_A3B = "Qwen/Qwen3-VL-30B-A3B-Instruct"
QWEN_3_VL_30B_A3B_FP8 = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8" QWEN_3_VL_30B_A3B_FP8 = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
QWEN_3_VL_8B_FP8 = "Qwen/Qwen3-VL-8B-Instruct-FP8" QWEN_3_VL_8B_FP8 = "Qwen/Qwen3-VL-8B-Instruct-FP8"
QWEN_3_VL_4B = "Qwen/Qwen3-VL-4B-Instruct"
QWEN_3_VL_4B_FP8 = "Qwen/Qwen3-VL-4B-Instruct-FP8"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf" LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
...@@ -124,6 +130,8 @@ QWEN_VL_MODELS = [ ...@@ -124,6 +130,8 @@ QWEN_VL_MODELS = [
SupportedModels.QWEN_3_VL_30B_A3B, SupportedModels.QWEN_3_VL_30B_A3B,
SupportedModels.QWEN_3_VL_30B_A3B_FP8, SupportedModels.QWEN_3_VL_30B_A3B_FP8,
SupportedModels.QWEN_3_VL_8B_FP8, SupportedModels.QWEN_3_VL_8B_FP8,
SupportedModels.QWEN_3_VL_4B,
SupportedModels.QWEN_3_VL_4B_FP8,
] ]
...@@ -159,7 +167,7 @@ def load_vision_model(model_id: str) -> torch.nn.Module: ...@@ -159,7 +167,7 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
# Load only the vision model via vLLM # Load only the vision model via vLLM
vllm_model = LLM( vllm_model = LLM(
model=model_id, model=model_id,
enforce_eager=True, enforce_eager=False,
kv_cache_memory_bytes=1024 kv_cache_memory_bytes=1024
* 1024 * 1024
* 8, # 8MB KV cache for vLLM to complete the init lifecycle, encoder-only doesn't require KV cache. * 8, # 8MB KV cache for vLLM to complete the init lifecycle, encoder-only doesn't require KV cache.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment