echo_tensor_worker.py 3.82 KB
Newer Older
1
#  SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#  SPDX-License-Identifier: Apache-2.0

# Usage: run `TEST_END_TO_END=1 python test_tensor.py` to start this worker as a tensor-based echo worker.


7
8
9
# The test is expected to run in an environment that has tritonclient installed,
# which contains the generated file equivalent to model_config.proto.
import tritonclient.grpc.model_config_pb2 as mc
10
11
import uvloop

12
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model
13
14
15
from dynamo.runtime import DistributedRuntime, dynamo_worker


16
@dynamo_worker()
async def echo_tensor_worker(runtime: DistributedRuntime):
    """Register the tensor-based "echo" model and serve ``generate`` for it.

    The worker builds a Triton ``ModelConfig`` describing the echo model,
    stores its serialized form in a ``ModelRuntimeConfig``, checks that the
    stored config reads back unchanged, registers the model on the
    ``tensor.echo.generate`` endpoint and then serves that endpoint.

    Args:
        runtime: Runtime used to obtain the endpoint to register and serve.
    """
    endpoint = runtime.endpoint("tensor.echo.generate")

    # Triton-side description of the model: one required string input, one
    # optional int32 input and one string output, each with dims [-1].
    triton_cfg = mc.ModelConfig(name="echo", platform="custom")
    triton_cfg.input.add(name="input", data_type=mc.TYPE_STRING, dims=[-1])
    triton_cfg.input.add(
        name="optional_input",
        data_type=mc.TYPE_INT32,
        dims=[-1],
        optional=True,
    )
    triton_cfg.output.add(
        name="dummy_output", data_type=mc.TYPE_STRING, dims=[-1]
    )
    triton_cfg.model_transaction_policy.decoupled = True

    # Only the serialized Triton config carries information here; the other
    # fields are intentionally left empty.
    model_config = {
        "name": "",
        "inputs": [],
        "outputs": [],
        "triton_model_config": triton_cfg.SerializeToString(),
    }

    runtime_config = ModelRuntimeConfig()
    runtime_config.set_tensor_model_config(model_config)

    # Round-trip sanity check. The bytes payload comes back as a list of
    # ints, so it is converted back to bytes before comparing.
    round_tripped = runtime_config.get_tensor_model_config()
    assert round_tripped is not None
    payload = round_tripped["triton_model_config"]
    round_tripped["triton_model_config"] = bytes(payload)
    assert model_config == round_tripped

    # register_model is the entry point for tensor-based backends
    # (it skips HuggingFace downloads).
    await register_model(
        ModelInput.Tensor,
        ModelType.TensorBased,
        endpoint,
        "echo",  # model_path (used as display name for tensor-based models)
        runtime_config=runtime_config,
    )

    await endpoint.serve_endpoint(generate)


67
async def generate(request):
    """Yield a single response that echoes the request back.

    The response carries the request's ``model`` and ``tensors`` plus a copy
    of its ``parameters`` (when present). Specific parameter keys switch on
    test-only fault injection, checked in this order:

    * ``malformed_response``: replace the first tensor's ``data`` with
      ``{"values": [0, 1, 2]}`` before echoing.
    * ``data_mismatch``: empty the first tensor's ``data["values"]`` before
      echoing.
    * ``raise_exception``: raise ``ValueError`` instead of responding.

    When none of the tampering keys applies, ``parameters["processed"]`` is
    set to ``{"bool": True}`` in the response.

    Raises:
        ValueError: If ``raise_exception`` is the first matching fault key.
    """
    # NOTE: there is currently no frontend-side validation of requests
    # against the model config, so every request ends up here and is echoed.
    print(f"Echoing request: {request}")

    echoed_params = {}
    tampered = False

    if "parameters" in request:
        requested = request["parameters"]
        echoed_params.update(requested)

        if "malformed_response" in requested:
            request["tensors"][0]["data"] = {"values": [0, 1, 2]}
            tampered = True
        elif "data_mismatch" in requested:
            # Empty the tensor values; meant to trigger a data mismatch error.
            request["tensors"][0]["data"]["values"] = []
            tampered = True
        elif "raise_exception" in requested:
            raise ValueError("Intentional exception raised by echo_tensor_worker.")

    # Tampered responses are returned without the "processed" marker.
    if not tampered:
        echoed_params["processed"] = {"bool": True}

    yield {
        "model": request["model"],
        "tensors": request["tensors"],
        "parameters": echoed_params,
    }
104
105
106
107


# Script entry point: when this file is executed directly, hand the worker
# coroutine to uvloop.run.
# NOTE(review): echo_tensor_worker is called without arguments here even
# though it declares a `runtime` parameter; presumably the @dynamo_worker()
# decorator supplies the DistributedRuntime. Confirm against dynamo.runtime.
if __name__ == "__main__":
    uvloop.run(echo_tensor_worker())