encode_worker.py 6.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from io import BytesIO
18
from queue import Queue
19
20
from typing import AsyncIterator

21
import connect
22
23
24
25
26
27
28
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args

29
from dynamo.sdk import async_on_start, endpoint, service
30
31
32

logger = logging.getLogger(__name__)

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Select the array backend: cupy when it imports AND reports CUDA as available,
# otherwise numpy. `DEVICE` records which of the two was chosen.
try:
    import cupy as array_module

    # Treat "cupy installed but no usable CUDA" exactly like a missing cupy,
    # so both cases share the numpy fallback below.
    if not array_module.cuda.is_available():
        raise ImportError("CUDA is not available.")
except ImportError as import_error:
    logger.warning(f"Failed to import cupy, falling back to numpy: {import_error}.")
    import numpy as array_module

    DEVICE = "cpu"
else:
    DEVICE = "cuda"
    logger.info("Using cupy for array operations (GPU mode).")

# Capacity of the worker's image cache FIFO (used as the `maxsize` of its eviction queue).
CACHE_SIZE_MAXIMUM = 8

48
49
50
51
52
53
54
55

@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
56
class VllmEncodeWorker:
57
58
59
60
61
62
63
64
65
66
67
68
69
    def __init__(self) -> None:
        """Load the image processor and vision model, and set up the image cache.

        The model identifier is taken from the `model` field of the engine
        arguments returned by `parse_vllm_args` for this class's name.
        """
        worker_name = self.__class__.__name__
        self.engine_args = parse_vllm_args(worker_name, "")

        model_id = self.engine_args.model
        self.MODEL_ID = model_id

        # Pre-processor that turns images into the tensors fed to the vision model.
        # NOTE(review): `trust_remote_code=True` allows code shipped with the model
        # repository to run - confirm the model source is trusted.
        self.image_processor = AutoImageProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )

        # Half-precision model, switched to evaluation mode before use.
        vision_model = LlavaForConditionalGeneration.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.float16
        )
        self.vision_model = vision_model.eval()

        # Cache of opened images keyed by URL, plus a bounded FIFO of the cached
        # keys that decides which entry is evicted once the cache is full.
        self._image_cache: dict[str, Image.Image] = {}
        self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)

73
    @endpoint()
    async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
        """Encode the requested image and write its embeddings to the remote worker.

        Steps:
            1. Open the image (from the cache, or via `open_image`).
            2. Process the image using the image processor.
            3. Run the image through the vision model's vision tower.
            4. Run the vision tower output through the multi-modal projector.
            5. Create a descriptor for the embeddings.
            6. Begin a write operation using the serialized request and the descriptor.
            7. Wait for the write operation to complete.
            8. Yield the encode response.

        Args:
            request: The encode request; its `request_id`, `image_url` and
                `serialized_request` fields are read.

        Yields:
            The `model_dump_json()` serialization of an `EncodeResponse` that
            carries the request id (a JSON string, despite the annotation).

        Raises:
            ValueError: If `request.serialized_request` is None.
        """
        logger.debug(
            f"Received encode request: {{ id: {request.request_id}, image_url: '{request.image_url}' }}."
        )

        request_id = request.request_id

        # Lower-case ONLY the scheme (so "HTTP://..." is still recognised as a URL).
        # URL paths/queries and local file paths are case-sensitive, so the rest of
        # the string must be kept verbatim for the download to hit the right resource.
        scheme, separator, remainder = request.image_url.partition("://")
        if separator:
            image_url = f"{scheme.lower()}{separator}{remainder}"
        else:
            image_url = request.image_url

        # Fail fast: without the serialized request there is nothing to write the
        # embeddings to, so do not spend time on the download and forward pass.
        if request.serialized_request is None:
            logger.error(
                f"Request serialized_request is None for request: {{ id: {request_id}, image_url: '{image_url}' }}."
            )
            raise ValueError(
                f"Encode request '{request_id}' has no serialized_request."
            )

        # Either retrieve the image from the cache or download it and then cache it.
        # The same key (`image_url`) is used for the membership test, the lookup and
        # the insertion, so a cache hit can never raise KeyError.
        image = self._image_cache.get(image_url)
        if image is not None:
            logger.debug(
                f"Image found in cache for request: {{ id: {request_id}, image_url: '{image_url}' }}."
            )
        else:
            logger.debug(
                f"Downloading/opening image for request: {{ id: {request_id}, image_url: '{image_url}' }}."
            )
            image = self.open_image(image_url)

            # Cache the image for future use, and evict the oldest image if the cache is full.
            # The `*_nowait` variants are used so a bookkeeping mismatch raises
            # instead of blocking the event loop.
            if self._cache_queue.full():
                oldest_image_url = self._cache_queue.get_nowait()
                self._image_cache.pop(oldest_image_url, None)

            self._image_cache[image_url] = image
            self._cache_queue.put_nowait(image_url)

        logger.debug(
            f"Processing image for request: {{ id: {request_id}, image_url: '{image_url}' }}"
        )
        image_embeds = self.image_processor(images=image, return_tensors="pt")

        # Keep `no_grad` limited to the forward pass. It is thread-local state, so
        # holding it across an `await`/`yield` would leave gradients disabled for
        # whatever else runs on this thread while this coroutine is suspended.
        with torch.no_grad():
            logger.debug(f"Vision model device: {self.vision_model.device}")
            vision_outputs = self.vision_model.vision_tower(
                image_embeds["pixel_values"].to(self.vision_model.device)
            )
            logger.debug("Vision model completed.")

            embeddings = vision_outputs.last_hidden_state
            embeddings = self.vision_model.multi_modal_projector(embeddings)

        logger.debug(
            f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
        )

        # Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
        descriptor = connect.Descriptor(embeddings)
        # Create a write operation using the serialized request and the descriptor.
        # This will begin the RDMA transfer of the embeddings to the remote worker.
        write_op = await self._connector.begin_write(
            descriptor,
            request.serialized_request,
        )
        # Wait for the write operation to complete.
        # This will block until the data has been written to the remote worker or an error occurs.
        await write_op.wait_for_completion()

        yield EncodeResponse(
            request_id=request.request_id,
        ).model_dump_json()

151
152
    @async_on_start
    async def async_init(self):
        """Create and initialize the dynamo connector used by this worker.

        The connector is stored on `self._connector`, which `encode` uses to
        begin write operations.
        """
        logger.info("Startup started.")
        # Create and initialize a dynamo connector for this worker.
        # We need this to move data between this worker and remote workers efficiently.
        connector = connect.Connector()
        self._connector = connector
        await connector.initialize()
        logger.info("Startup completed.")

160
161
162
    def open_image(self, image: str, timeout: float = 30.0) -> Image.Image:
        """Open an image from an HTTP(S) URL or a local path and convert it to RGB.

        Args:
            image: Either an `http://` / `https://` URL (scheme matched
                case-insensitively) or a path that `PIL.Image.open` can open.
            timeout: Seconds to wait for the HTTP server before giving up.
                Only used for URLs.

        Returns:
            The image converted to RGB, the format the image processor is fed.

        Raises:
            Exception: Any error raised while downloading, opening or converting
                the image is logged and re-raised unchanged.
        """
        # TODO: Have a separate field for url and non url - and avoid auto detection
        try:
            # Require the full "scheme://" prefix so that a local file whose name
            # merely starts with "http" is not mistaken for a URL.
            if image.lower().startswith(("http://", "https://")):
                # A timeout keeps a stalled server from hanging the worker, and
                # `raise_for_status` surfaces HTTP errors directly instead of
                # handing an error page to PIL.
                response = requests.get(image, timeout=timeout)
                response.raise_for_status()
                source = BytesIO(response.content)
            else:
                source = image

            # Acquire the image and convert it to the format (RGB) the image processor model expects.
            # `convert` returns a new, fully loaded image, so the source image (and
            # any file handle PIL opened for it) can be closed right away.
            with Image.open(source) as opened_image:
                return opened_image.convert("RGB")
        except Exception as e:
            logger.error(f"Error opening image: {e}")
            raise