Unverified commit 80955ef4, authored by Kris Hung, committed by GitHub
Browse files

perf: Keep embeddings on GPU in Embedding Sender in EPD pipeline + minor fixes (#6535)

parent 5734f5c4
......@@ -361,13 +361,11 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender):
Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
"""
# If not staging embedding and embedding is on CPU, we explicitly copy
# the tensor as torch.Tensor.cpu() will return original tensor if it's already on CPU
if not stage_embeddings and not embeddings.is_cuda:
embeddings_cpu = embeddings.clone().detach()
if stage_embeddings:
transfer_buf = embeddings
else:
embeddings_cpu = embeddings.cpu()
descriptor = nixl_connect.Descriptor(embeddings_cpu)
transfer_buf = embeddings.clone().detach()
descriptor = nixl_connect.Descriptor(transfer_buf)
readable_op = await self.connector.create_readable(descriptor)
request = TransferRequest(
......
......@@ -42,7 +42,7 @@ ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1))
class EmbeddingItem:
key: str
image_grid_thw: list
embeddings_cpu: torch.Tensor
embeddings: torch.Tensor
class EncodeWorkerHandler:
......@@ -140,11 +140,11 @@ class EncodeWorkerHandler:
if self.embedding_cache is not None and self.embedding_cache.has_key(
embedding_key
):
(image_grid_thw, embeddings_cpu) = self.embedding_cache.get(
(image_grid_thw, embeddings) = self.embedding_cache.get(
embedding_key
)
embedding_lists[idx] = EmbeddingItem(
embedding_key, image_grid_thw, embeddings_cpu
embedding_key, image_grid_thw, embeddings
)
# compute
else:
......@@ -200,7 +200,7 @@ class EncodeWorkerHandler:
// merge_size
// merge_size
).tolist()
splitted_embeddings = embeddings.cpu().squeeze(0).split(sizes)
splitted_embeddings = embeddings.squeeze(0).split(sizes)
logger.debug(
f"Splitted embeddings lengths: {[e.shape for e in splitted_embeddings]}"
)
......@@ -209,7 +209,7 @@ class EncodeWorkerHandler:
# embeddings already has batch dimension for images, so we can directly
# split by batch dimension
logger.debug(f"image embedding shape: {embeddings.shape}")
splitted_embeddings = embeddings.cpu()
splitted_embeddings = embeddings
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
......@@ -230,7 +230,7 @@ class EncodeWorkerHandler:
embedding_lists[list_idx].key,
(
embedding_lists[list_idx].image_grid_thw,
embedding_lists[list_idx].embeddings_cpu,
embedding_lists[list_idx].embeddings,
),
)
......@@ -240,7 +240,7 @@ class EncodeWorkerHandler:
send_tasks = [
asyncio.create_task(
self.embedding_sender.send_embeddings(
embedding_item.embeddings_cpu, stage_embeddings=True
embedding_item.embeddings, stage_embeddings=True
)
)
for embedding_item in embedding_lists
......@@ -252,7 +252,7 @@ class EncodeWorkerHandler:
for idx, item in enumerate(zip(embedding_lists, transfer_requests)):
embedding_item, transfer_request = item
logger.debug(
f"{embedding_item.embeddings_cpu.shape} prepared for transfer."
f"{embedding_item.embeddings.shape} prepared for transfer."
)
# Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None
......@@ -260,13 +260,13 @@ class EncodeWorkerHandler:
idx
].image_grid_thw = embedding_item.image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = tuple(
embedding_item.embeddings_cpu.shape
embedding_item.embeddings.shape
)
request.multimodal_inputs[idx].serialized_request = transfer_request[0]
# Keep a reference of the embedding_cpu and only drop reference when the transfer is done
# Keep a reference of the embedding and only drop reference when the transfer is done
self.send_complete_queue.put_nowait(
(transfer_request[1], embedding_item.embeddings_cpu)
(transfer_request[1], embedding_item.embeddings)
)
logger.debug(f"Request: {request.model_dump_json()}")
......
......@@ -129,6 +129,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
for item in mm_data.get(IMAGE_URL_KEY, []):
if isinstance(item, dict) and "Url" in item:
image_urls.append(item["Url"])
elif isinstance(item, dict) and "Decoded" in item:
image_urls.append(item["Decoded"])
sampling_params = build_sampling_params(
raw_request, self.default_sampling_params
......
......@@ -33,6 +33,10 @@ VLLM_ENCODER = int(os.getenv("VLLM_ENCODER", 1))
class SupportedModels:
"""Supported multimodal model identifiers"""
# TODO: Replace this explicit model list with dynamic detection using
# HF config `architectures` field or vLLM's model registry, so any
# vLLM-supported VLM works without maintaining entries here.
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct"
QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
......@@ -42,6 +46,8 @@ class SupportedModels:
QWEN_3_VL_30B_A3B = "Qwen/Qwen3-VL-30B-A3B-Instruct"
QWEN_3_VL_30B_A3B_FP8 = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
QWEN_3_VL_8B_FP8 = "Qwen/Qwen3-VL-8B-Instruct-FP8"
QWEN_3_VL_4B = "Qwen/Qwen3-VL-4B-Instruct"
QWEN_3_VL_4B_FP8 = "Qwen/Qwen3-VL-4B-Instruct-FP8"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
......@@ -124,6 +130,8 @@ QWEN_VL_MODELS = [
SupportedModels.QWEN_3_VL_30B_A3B,
SupportedModels.QWEN_3_VL_30B_A3B_FP8,
SupportedModels.QWEN_3_VL_8B_FP8,
SupportedModels.QWEN_3_VL_4B,
SupportedModels.QWEN_3_VL_4B_FP8,
]
......@@ -159,7 +167,7 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
# Load only the vision model via vLLM
vllm_model = LLM(
model=model_id,
enforce_eager=True,
enforce_eager=False,
kv_cache_memory_bytes=1024
* 1024
* 8, # 8MB KV cache for vLLM to complete the init lifecycle, encoder-only doesn't require KV cache.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment