echo_tensor_worker.py 3.04 KB
Newer Older
1
#  SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#  SPDX-License-Identifier: Apache-2.0

# Usage: `TEST_END_TO_END=1 python test_tensor.py` runs this worker as a tensor-based echo worker.


7
8
9
# The test is known to run in an environment that has tritonclient installed,
# which contains the generated module equivalent to model_config.proto.
import tritonclient.grpc.model_config_pb2 as mc
10
11
12
13
14
15
import uvloop

from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker


16
def _build_echo_triton_config():
    """Build the Triton ``ModelConfig`` protobuf advertised by the echo worker.

    The config is named "echo" on the "custom" platform and declares:
      * ``input``          - TYPE_STRING, dims [-1]
      * ``optional_input`` - TYPE_INT32, dims [-1], flagged optional
      * ``dummy_output``   - TYPE_STRING, dims [-1]
    The decoupled model transaction policy is enabled.
    """
    config = mc.ModelConfig()
    config.name = "echo"
    config.platform = "custom"

    required_input = config.input.add()
    required_input.name = "input"
    required_input.data_type = mc.TYPE_STRING
    required_input.dims.extend([-1])

    optional_input = config.input.add()
    optional_input.name = "optional_input"
    optional_input.data_type = mc.TYPE_INT32
    optional_input.dims.extend([-1])
    optional_input.optional = True

    output = config.output.add()
    output.name = "dummy_output"
    output.data_type = mc.TYPE_STRING
    output.dims.extend([-1])

    config.model_transaction_policy.decoupled = True
    return config


@dynamo_worker()
async def echo_tensor_worker(runtime: DistributedRuntime):
    """Register the tensor echo model and serve its ``generate`` endpoint.

    Steps:
      1. Resolve the ``tensor`` / ``echo`` / ``generate`` endpoint on ``runtime``.
      2. Store a tensor model config (carrying the serialized Triton config)
         in a ``ModelRuntimeConfig`` and check that it round-trips unchanged.
      3. Register the model through ``register_llm`` and serve ``generate``.

    Args:
        runtime: Distributed runtime used to resolve the endpoint.

    Raises:
        AssertionError: If the config read back from the runtime config
            differs from the one that was stored.
    """
    component = runtime.namespace("tensor").component("echo")
    endpoint = component.endpoint("generate")

    model_config = {
        "name": "",
        "inputs": [],
        "outputs": [],
        "triton_model_config": _build_echo_triton_config().SerializeToString(),
    }
    runtime_config = ModelRuntimeConfig()
    runtime_config.set_tensor_model_config(model_config)

    # Internally the bytes string is converted to a list of ints, so turn it
    # back into bytes before comparing. Work on a copy so the object returned
    # by the runtime config is left untouched.
    round_tripped = dict(runtime_config.get_tensor_model_config())
    round_tripped["triton_model_config"] = bytes(
        round_tripped["triton_model_config"]
    )
    assert (
        model_config == round_tripped
    ), "tensor model config did not round-trip through ModelRuntimeConfig"

    # Use register_llm for tensor-based backends (skips HuggingFace downloads)
    await register_llm(
        ModelInput.Tensor,
        ModelType.TensorBased,
        endpoint,
        "echo",  # model_path (used as display name for tensor-based models)
        runtime_config=runtime_config,
    )

    await endpoint.serve_endpoint(generate)


async def generate(request, context):
    """Echo tensors and parameters back to the client.

    Args:
        request: Request mapping. ``"model"`` and ``"tensors"`` are read
            unconditionally; ``"parameters"`` is optional and may be ``None``.
        context: Request context; not used by this handler.

    Yields:
        A single response dict with the request's ``"model"`` and
        ``"tensors"`` plus a ``"parameters"`` dict holding the request's
        parameters and an added ``"processed": {"bool": True}`` entry.

    Raises:
        KeyError: If ``request`` lacks ``"model"`` or ``"tensors"``.
    """
    # [NOTE] gluo: currently there is no frontend side
    # validation between model config and actual request,
    # so any request will reach here and be echoed back.
    print(f"Echoing request: {request}")

    # Build a fresh dict so the caller's parameters are never mutated, and
    # treat an explicit ``None`` the same as a missing key instead of
    # failing inside ``dict.update``.
    params = {}
    incoming = request["parameters"] if "parameters" in request else None
    if incoming is not None:
        params.update(incoming)

    params["processed"] = {"bool": True}

    yield {
        "model": request["model"],
        "tensors": request["tensors"],
        "parameters": params,
    }
86
87
88
89


if __name__ == "__main__":
    # Script entry point: hand the worker coroutine to uvloop.run.
    # NOTE(review): echo_tensor_worker is called without arguments here —
    # presumably @dynamo_worker() supplies the runtime; confirm.
    uvloop.run(echo_tensor_worker())