sender.py 4.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging

import torch
import uvloop
from protocol import BatchTransferRequest, EmbeddingTransferMode, TransferConfig

from dynamo.common.multimodal.embedding_transfer import (
    LocalEmbeddingSender,
    NixlReadEmbeddingSender,
    NixlWriteEmbeddingSender,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging

# Module-level logger for this worker. Dynamo logging is configured at import
# time, i.e. before any Sender is constructed or any endpoint is served.
logger = logging.getLogger(__name__)
configure_dynamo_logging()


class Sender:
    """Example embedding sender that exercises the embedding-transfer paths.

    An instance owns one sender per transfer mode (local, NIXL read, NIXL
    write), a fixed random CPU tensor and, when CUDA is available, a GPU tensor
    of the same shape. The active ``TransferConfig`` decides which sender and
    which tensor a request uses; it can be swapped at runtime through
    ``update_config``.
    """

    # Shape and dtype of the placeholder tensors created in ``__init__``.
    _TENSOR_SHAPE = (256, 8 * 1024)
    _TENSOR_DTYPE = torch.float16

    def __init__(self, runtime: "DistributedRuntime"):
        """Create the senders, the placeholder tensors and the default config.

        Args:
            runtime: Distributed runtime used later by ``async_init`` to look
                up the receiver endpoint.
        """
        self.runtime = runtime
        # One sender per supported transfer mode; get_run_config() picks one.
        self.local_sender = LocalEmbeddingSender()
        self.read_sender = NixlReadEmbeddingSender()
        self.write_sender = NixlWriteEmbeddingSender()
        # CPU and GPU tensors that mimic encoder output. The GPU tensor is only
        # allocated when CUDA is available; otherwise it stays None and GPU
        # mode is rejected in get_run_config().
        self.cpu_tensor = torch.randn(self._TENSOR_SHAPE, dtype=self._TENSOR_DTYPE)
        self.gpu_tensor = (
            torch.randn(self._TENSOR_SHAPE, dtype=self._TENSOR_DTYPE, device="cuda")
            if torch.cuda.is_available()
            else None
        )
        self.config = TransferConfig()
        # Filled in by async_init(); declared here so that an un-initialised
        # instance is detectable instead of failing with AttributeError.
        self.receiver_read_endpoint = None
        self.read_client = None

    def get_run_config(self):
        """Resolve the current config into concrete objects for one request.

        Returns:
            A ``(sender, tensor, tensor_count)`` tuple: the sender matching
            ``config.transfer_type``, the GPU or CPU tensor depending on
            ``config.use_gpu``, and ``config.tensor_count_per_request``.

        Raises:
            RuntimeError: GPU mode is configured but no GPU tensor exists
                because CUDA was not available at construction time.
            ValueError: ``config.transfer_type`` is not a known transfer mode.
        """
        config = self.config
        if config.use_gpu and self.gpu_tensor is None:
            raise RuntimeError("GPU mode requested but CUDA is not available.")

        # Select the sender variant based on the configured transfer mode.
        transfer_type = config.transfer_type
        if transfer_type == EmbeddingTransferMode.LOCAL:
            sender = self.local_sender
        elif transfer_type == EmbeddingTransferMode.NIXL_WRITE:
            sender = self.write_sender
        elif transfer_type == EmbeddingTransferMode.NIXL_READ:
            sender = self.read_sender
        else:
            raise ValueError(f"Invalid transfer type: {transfer_type}")

        tensor = self.gpu_tensor if config.use_gpu else self.cpu_tensor
        return sender, tensor, config.tensor_count_per_request

    async def async_init(self):
        """Create the client for the receiver's ``read`` endpoint.

        Must be awaited once before ``generate`` is served.
        """
        self.receiver_read_endpoint = self.runtime.endpoint(
            "embedding_transfer.receiver.read"
        )
        self.read_client = await self.receiver_read_endpoint.client()
        # NOTE(review): waiting for receiver instances is currently disabled;
        # confirm whether the sender should block here until one is registered.
        # await self.read_client.wait_for_instances()

    async def _stage_batch(self):
        """Stage the configured number of tensors with the configured sender.

        Sends are issued one after another, not concurrently.

        Returns:
            A ``(batch, send_futures)`` tuple: a ``BatchTransferRequest``
            holding one transfer request per staged tensor, and the matching
            list of send futures that the caller must await.
        """
        sender, tensor, tensor_count = self.get_run_config()

        batch = BatchTransferRequest(requests=[])
        send_futures = []
        for _ in range(tensor_count):
            transfer_request, send_future = await sender.send_embeddings(
                tensor, stage_embeddings=True
            )
            batch.requests.append(transfer_request)
            send_futures.append(send_future)
        return batch, send_futures

    async def generate(self, request: str):
        """Stage a batch, push it to the receiver and wait for completion.

        The serialized batch is sent to the receiver's ``read`` endpoint via
        round-robin, the response stream is drained, and all send futures are
        awaited before the single result is yielded.

        Args:
            request: Incoming payload; its content is not used.

        Yields:
            The string ``"done"`` once every send future has completed.

        Raises:
            RuntimeError: ``async_init`` has not been awaited yet, or GPU mode
                is configured without CUDA.
            ValueError: The configured transfer type is invalid.
        """
        # Fail early, before anything is staged, if there is no receiver client.
        if self.read_client is None:
            raise RuntimeError(
                "Sender.async_init() must be awaited before generate() is called."
            )

        batch, send_futures = await self._stage_batch()
        stream = await self.read_client.round_robin(batch.model_dump_json())
        # Responses carry nothing this method needs; just consume the stream.
        async for _ in stream:
            pass
        # NOTE(review): presumably each future resolves once the receiver has
        # consumed the matching embedding — confirm against the sender classes.
        await asyncio.gather(*send_futures)
        yield "done"

    async def write(self, request: str):
        """Stage a batch and hand the transfer requests back to the caller.

        Unlike ``generate``, nothing is pushed to the receiver from here: the
        serialized batch is yielded to the caller, and only when the generator
        is resumed afterwards are the send futures awaited.

        Args:
            request: Incoming payload; its content is not used.

        Yields:
            The JSON-serialized ``BatchTransferRequest``.

        Raises:
            RuntimeError: GPU mode is configured without CUDA.
            ValueError: The configured transfer type is invalid.
        """
        batch, send_futures = await self._stage_batch()
        yield batch.model_dump_json()
        await asyncio.gather(*send_futures)

    async def update_config(self, request: str):
        """Replace the active transfer configuration.

        Args:
            request: JSON document that validates as a ``TransferConfig``.

        Yields:
            The string ``"config updated"`` after the new config is installed.
        """
        self.config = TransferConfig.model_validate_json(request)
        yield "config updated"


@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Build a ``Sender`` and serve its endpoints on the given runtime.

    The sender is initialised first, then the ``generate``, ``write`` and
    ``update_config`` endpoints of ``embedding_transfer.sender`` are served
    concurrently; the coroutine returns when all of them have finished.
    """
    namespace_name, component_name = "embedding_transfer", "sender"
    endpoint_prefix = f"{namespace_name}.{component_name}"

    sender = Sender(runtime)
    await sender.async_init()

    # Endpoint name -> handler, in the order the endpoints are announced,
    # created and served.
    handlers = {
        "generate": sender.generate,
        "write": sender.write,
        "update_config": sender.update_config,
    }

    logger.info(f"Created service {namespace_name}/{component_name}")
    for endpoint_name in handlers:
        logger.info(f"Serving endpoint {endpoint_prefix}.{endpoint_name}")

    # Create every endpoint before any of them starts serving.
    bound_endpoints = [
        (runtime.endpoint(f"{endpoint_prefix}.{endpoint_name}"), handler)
        for endpoint_name, handler in handlers.items()
    ]
    await asyncio.gather(
        *(endpoint.serve_endpoint(handler) for endpoint, handler in bound_endpoints)
    )


# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # uvloop must be installed before asyncio.run() creates the event loop.
    uvloop.install()
    asyncio.run(worker())