template_verifier.py 2.35 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
# SPDX-License-Identifier: Apache-2.0

4
import os
5
6
7
8
9
10
import sys
from pathlib import Path

import uvloop
from transformers import AutoTokenizer

11
from dynamo.common.utils.paths import WORKSPACE_DIR
12
from dynamo.llm import ModelInput, ModelType, register_model
13
14
from dynamo.runtime import DistributedRuntime, dynamo_worker

15
# Path of the "tests/serve" directory under WORKSPACE_DIR; main() looks up
# the custom chat-template fixture beneath this directory.
SERVE_TEST_DIR = os.path.join(WORKSPACE_DIR, "tests/serve")
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40


class TemplateVerificationHandler:
    """Test backend that reports whether the custom chat template was applied.

    The handler decodes the token ids it receives and searches the resulting
    text for ``template_marker``. The outcome is sent back as a tokenized
    status sentence rather than as a real model completion.
    """

    def __init__(self, model_name="Qwen/Qwen3-0.6B"):
        """Load the tokenizer for *model_name* and set the marker to look for."""
        # Substring whose presence in the decoded prompt counts as success.
        # Presumably emitted by the custom template fixture — confirm there.
        self.template_marker = "CUSTOM_TEMPLATE_ACTIVE|"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    async def generate(self, request, context):
        """Yield one chunk whose token ids encode the verification verdict.

        Args:
            request: Mapping read via ``request.get("token_ids", [])``; a
                missing key is treated as an empty prompt.
            context: Accepted for the endpoint signature; not used here.

        Yields:
            A single dict ``{"token_ids": [...]}`` holding the encoded status
            sentence (encoded without special tokens).
        """
        prompt_text = self.tokenizer.decode(request.get("token_ids", []))
        marker_found = self.template_marker in prompt_text

        verdict = (
            "Successfully Applied Chat Template"
            if marker_found
            else "Failed to Apply Chat Template"
        )

        # The verdict goes back as token ids so the frontend can detokenize it.
        yield {
            "token_ids": self.tokenizer.encode(verdict, add_special_tokens=False)
        }


41
@dynamo_worker()
async def main(runtime: DistributedRuntime):
    """Run the template-verification worker.

    Builds the ``test`` / ``backend`` / ``generate`` endpoint on *runtime*,
    registers the Qwen model together with the custom chat-template fixture,
    and then serves requests through ``TemplateVerificationHandler.generate``.
    Exits the process with status 1 if the template fixture does not exist.
    """
    # Endpoint this worker is served on: namespace "test", component "backend".
    backend = runtime.namespace("test").component("backend")
    generate_endpoint = backend.endpoint("generate")

    # The custom template is an existing fixture under the serve tests.
    fixtures_dir = Path(SERVE_TEST_DIR) / "fixtures"
    jinja_file = fixtures_dir / "custom_template.jinja"
    if not jinja_file.exists():
        print(f"Error: Template not found at {jinja_file}")
        sys.exit(1)

    qwen_model = "Qwen/Qwen3-0.6B"

    # NOTE(review): the model identifier is passed twice, once positionally
    # and once as model_name=. Presumably the positional slot is the model
    # path/source — confirm against register_model's signature.
    await register_model(
        ModelInput.Tokens,
        ModelType.Chat,
        generate_endpoint,
        qwen_model,
        model_name=qwen_model,
        custom_template_path=str(jinja_file),
    )

    # Serve requests with a handler using the same model's tokenizer.
    verifier = TemplateVerificationHandler(qwen_model)
    await generate_endpoint.serve_endpoint(verifier.generate)


if __name__ == "__main__":
    # main() is called without arguments here; presumably the
    # @dynamo_worker() decorator supplies the DistributedRuntime — confirm
    # against dynamo.runtime.
    uvloop.run(main())