hf.py 23.1 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
4
5
import itertools
from collections import defaultdict, deque
6
7
from collections.abc import Set
from functools import lru_cache
8
from typing import TYPE_CHECKING, Any, cast
9
10
11
12
13
14
15
16

import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox

17
from vllm.config import ModelConfig, VllmConfig
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormat,
    ChatTemplateContentFormatOption,
    ChatTemplateResolutionError,
    ConversationMessage,
    load_chat_template,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw

35
from .base import BaseRenderer
36
37
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
38
from .params import ChatParams
39

40
41
42
43
44
45
46
if TYPE_CHECKING:
    from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
else:
    MultiModalDataDict = dict[str, Any]
    MultiModalUUIDDict = dict[str, Any]


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
logger = init_logger(__name__)


_PROCESSOR_CHAT_TEMPLATES: dict[tuple[str, bool], str | None] = {}
"""
Memo of ``(tokenizer name_or_path, trust_remote_code) -> processor chat
template`` consulted by `_try_get_processor_chat_template`.

Misses (``None``) are recorded as well, so `cached_get_processor` is not
retried for a processor that failed to load; `lru_cache` cannot provide this
because it does not cache calls that raise.
"""


def _try_get_processor_chat_template(
    tokenizer: HfTokenizer,
    *,
    trust_remote_code: bool,
) -> str | None:
    """Return the chat template defined on the model's ``AutoProcessor``.

    The outcome, including a ``None`` miss, is memoized in
    ``_PROCESSOR_CHAT_TEMPLATES`` under ``(name_or_path, trust_remote_code)``
    so a processor that fails to load is not loaded again.

    Args:
        tokenizer: Tokenizer whose ``name_or_path`` identifies the model.
        trust_remote_code: Forwarded to ``cached_get_processor``.

    Returns:
        The processor's ``chat_template`` when the loaded object is a
        ``ProcessorMixin`` with a non-``None`` template, otherwise ``None``.
    """
    memo_key = (tokenizer.name_or_path, trust_remote_code)
    try:
        return _PROCESSOR_CHAT_TEMPLATES[memo_key]
    except KeyError:
        pass

    from transformers import (
        PreTrainedTokenizer,
        PreTrainedTokenizerFast,
        ProcessorMixin,
    )

    found_template: str | None = None
    try:
        loaded = cached_get_processor(
            tokenizer.name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=trust_remote_code,
        )
        # Plain tokenizers are accepted by the loader but only real
        # processors are consulted for a template here.
        if isinstance(loaded, ProcessorMixin):
            found_template = getattr(loaded, "chat_template", None)
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    _PROCESSOR_CHAT_TEMPLATES[memo_key] = found_template
    return found_template


def resolve_chat_template(
    tokenizer: HfTokenizer,
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: "ModelConfig",
) -> str | None:
    """Choose the chat template to use, trying sources in priority order.

    1. ``chat_template`` itself, when it is not ``None``.
    2. The ``AutoProcessor`` chat template (skipped when ``tools`` is given).
    3. ``tokenizer.get_chat_template``; failures are logged at debug level.
    4. A predefined fallback template for the model, if one exists.

    Returns:
        The resolved template text, or ``None`` if no source provides one.
    """
    if chat_template is not None:
        return chat_template

    if tools is None:
        processor_template = _try_get_processor_chat_template(
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )
        if processor_template is not None:
            return processor_template

    # No explicit template reached this point, so none is passed on.
    try:
        return tokenizer.get_chat_template(None, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    fallback_path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=tokenizer.name_or_path,
    )
    if fallback_path is None:
        logger.debug_once(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )
        return None

    logger.info_once(
        "Loading chat template fallback for %s as there isn't one "
        "defined on HF Hub.",
        tokenizer.name_or_path,
    )
    return load_chat_template(fallback_path)


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return whether ``node`` loads (reads) the variable named ``varname``."""
    return (
        isinstance(node, jinja2.nodes.Name)
        and node.ctx == "load"
        and node.name == varname
    )


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return whether ``node`` is ``varname.key`` or ``varname[key]``.

    Subscript access only matches when the subscript is a constant equal to
    ``key``.
    """
    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    if not isinstance(node, jinja2.nodes.Getitem):
        return False

    subscript = node.arg
    return (
        _is_var_access(node.node, varname)
        and isinstance(subscript, jinja2.nodes.Const)
        and subscript.value == key
    )


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str | None = None,
) -> bool:
    """Return whether ``node`` accesses ``varname`` (or ``varname[key]``).

    Filters (``x | f``), tests (``x is t``) and slices (``x[a:b]``) are
    looked through, so e.g. ``messages[1:] | reverse`` still counts as an
    access of ``messages``. When ``key`` is falsy, a plain variable read is
    required; otherwise an attribute/subscript access with ``key``.
    """
    if isinstance(node, jinja2.nodes.Filter):
        # A filter block (`{% filter %}`) has no wrapped expression.
        if node.node is None:
            return False
        return _is_var_or_elems_access(node.node, varname, key)

    is_test = isinstance(node, jinja2.nodes.Test)
    is_slice = isinstance(node, jinja2.nodes.Getitem) and isinstance(
        node.arg, jinja2.nodes.Slice
    )
    if is_test or is_slice:
        return _is_var_or_elems_access(node.node, varname, key)

    if not key:
        return _is_var_access(node, varname)
    return _is_attr_access(node, varname, key)


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield ``(node, name)`` for ``varname`` and every variable derived from it.

    The first pair is ``(root, varname)``: the variable as implicitly defined
    at the template root. After that, a breadth-first search yields each
    ``{% set lhs = rhs %}`` whose right-hand side accesses an already-known
    name (per ``_is_var_or_elems_access``), paired with the assigned name.

    Each name is expanded at most once. This covers self-assignment
    (``{% set x = x[1:] %}``) and also assignment cycles such as
    ``{% set a = messages %}{% set messages = a %}``, which previously made
    the search loop forever. Callers only use the yielded names, so skipping
    re-expansion does not change which names are discovered.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # The set of assignments is fixed, so collect it once instead of
    # re-walking the whole AST for every name in the queue.
    assign_nodes = list(root.find_all(jinja2.nodes.Assign))

    expanded = {varname}
    pending = deque([varname])
    while pending:
        related_varname = pending.popleft()

        for assign_ast in assign_nodes:
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                # Only enqueue names not seen before: guarantees termination
                # even when variables are reassigned in a cycle.
                if lhs.name not in expanded:
                    expanded.add(lhs.name)
                    pending.append(lhs.name)


# NOTE: Tracking which scope each variable is defined in would require a CFG,
# which is too complicated here, so variable names are matched globally.
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield ``(for_node, loop_var)`` for each ``{% for x in <messages> %}``.

    ``<messages>`` is ``messages`` itself or any variable derived from it via
    ``{% set %}`` (see ``_iter_nodes_assign_var_or_elems``).
    """
    messages_aliases = [
        alias for _, alias in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    for for_node in root.find_all(jinja2.nodes.For):
        if any(
            _is_var_or_elems_access(for_node.iter, alias)
            for alias in messages_aliases
        ):
            assert isinstance(for_node.target, jinja2.nodes.Name)
            yield for_node, for_node.target.name


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield ``(for_node, loop_var)`` for each loop over a message's content.

    Matches ``{% for x in m['content'] %}`` / ``{% for x in m.content %}``,
    where ``m`` is a loop variable bound by a loop over messages (see
    ``_iter_nodes_assign_messages_item``).
    """
    message_names = [name for _, name in _iter_nodes_assign_messages_item(root)]

    for for_node in root.find_all(jinja2.nodes.For):
        iterates_content = any(
            _is_var_or_elems_access(for_node.iter, name, "content")
            for name in message_names
        )
        if not iterates_content:
            continue
        assert isinstance(for_node.target, jinja2.nodes.Name)
        yield for_node, for_node.target.name


def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
    """Parse ``chat_template`` into a Jinja AST, or return ``None`` on failure.

    The template is compiled with ``transformers``' chat-template compiler and
    the resulting template's environment is reused to parse the source.
    Any error is logged and swallowed.
    """
    import transformers.utils.chat_template_utils as hf_chat_utils

    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        template_ast = compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None

    return template_ast


@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: ChatTemplateContentFormat,
) -> ChatTemplateContentFormat:
    """Infer which message-content format ``chat_template`` expects.

    Returns:
        ``"openai"`` if the template loops over a message's ``content``
        (treating it as a list of parts), ``"string"`` if it never does, and
        ``default`` if the template cannot be parsed or analysed.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # The generator yields tuples, so None unambiguously means "no match".
        first_content_loop = next(_iter_nodes_assign_content_item(jinja_ast), None)
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default

    return "string" if first_content_loop is None else "openai"


def _resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    tokenizer: HfTokenizer,
    *,
    model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
    """Resolve the effective chat template and detect its content format.

    Falls back to ``"string"`` when no template text is available or when
    detection cannot analyse the template.
    """
    resolved = resolve_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if isinstance(resolved, str):
        template_text = resolved
    else:
        template_text = load_chat_template(chat_template, is_literal=True)

    if template_text is None:
        return "string"
    return _detect_content_format(template_text, default="string")


@lru_cache
def _log_chat_template_content_format(
    chat_template: str | None,  # For caching purposes
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format, warning if it conflicts with the user's.

    Wrapped in ``lru_cache`` so that each distinct combination of arguments
    is logged only once while it stays in the cache.
    """
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    # Nothing to warn about if the user left it on auto or agrees with us.
    if given_format == "auto" or given_format == detected_format:
        return

    logger.warning(
        "You specified `--chat-template-content-format %s` "
        "which is different from the detected format '%s'. "
        "If our automatic detection is incorrect, please consider "
        "opening a GitHub issue so that we can improve it: "
        "https://github.com/vllm-project/vllm/issues/new/choose",
        given_format,
        detected_format,
    )


def resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    given_format: ChatTemplateContentFormatOption,
    tokenizer: HfTokenizer,
    *,
    model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
    """Return the content format to use when parsing chat messages.

    An explicit ``given_format`` is returned as-is; ``"auto"`` triggers
    detection from the resolved chat template, and the result is logged.
    """
    if given_format == "auto":
        detected = _resolve_chat_template_content_format(
            chat_template,
            tools,
            tokenizer,
            model_config=model_config,
        )
        _log_chat_template_content_format(
            chat_template,
            given_format=given_format,
            detected_format=detected,
        )
        return detected

    return given_format


# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja extension that parses ``{% generation %}...{% endgeneration %}``.

    The block body is wrapped in a ``CallBlock`` that calls the extension's
    ``_generation_support`` method. That method is not defined here, so this
    class is only suitable for parsing, not rendering.
    """

    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node:
        # Consume the `generation` tag token, remembering where it started.
        start_lineno = next(parser.stream).lineno
        inner_body = parser.parse_statements(
            ("name:endgeneration",), drop_needle=True
        )
        return jinja2.nodes.CallBlock(
            self.call_method("_generation_support"), [], [], inner_body
        ).set_lineno(start_lineno)


def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]:
    """Return the names of variables the template uses but never defines.

    These are the variables that must be supplied by the caller when the
    template is rendered. Parsing uses a sandboxed environment with the
    ``generation`` tag and loop-control extensions enabled.
    """
    sandbox_env = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    return jinja2.meta.find_undeclared_variables(sandbox_env.parse(chat_template))


# Memoized variant (lru_cache default size of 128 entries).
_cached_resolve_chat_template_kwargs = lru_cache(maxsize=128)(
    _resolve_chat_template_kwargs
)


@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
    """Return the named parameters of HF's base ``apply_chat_template``.

    Reading them from ``PreTrainedTokenizer.apply_chat_template`` keeps
    vLLM in sync with ``transformers``. This lets standard parameters through
    even for tokenizers whose own ``apply_chat_template`` only takes
    ``**kwargs``. ``*args``/``**kwargs`` placeholders themselves are excluded.
    """
    from transformers import PreTrainedTokenizer

    variadic_kinds = {
        inspect.Parameter.VAR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
    }
    parameters = inspect.signature(PreTrainedTokenizer.apply_chat_template).parameters
    return frozenset(
        name for name, param in parameters.items() if param.kind not in variadic_kinds
    )


def resolve_chat_template_kwargs(
    tokenizer: HfTokenizer,
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
    raise_on_unexpected: bool = True,
) -> dict[str, Any]:
    """Keep only the request kwargs that the chat template can make use of.

    A key is kept if any of the following holds:

    - it is an explicit keyword parameter of ``tokenizer.apply_chat_template``;
    - it is an undeclared variable referenced by ``chat_template``;
    - it is a parameter of HF's base ``apply_chat_template``.

    ``chat_template`` and ``tokenize`` are always dropped.

    Raises:
        ValueError: If ``raise_on_unexpected`` is true and the kwargs contain
            ``chat_template`` or ``tokenize``.
    """
    # The template has already been resolved and `tokenize` is controlled by
    # the caller, so neither may be overridden through request kwargs.
    reserved = {"chat_template", "tokenize"}
    if raise_on_unexpected:
        offending = reserved & chat_template_kwargs.keys()
        if offending:
            raise ValueError(
                "Found unexpected chat template kwargs from request: "
                f"{offending}"
            )

    explicit_kw = {
        name
        for name in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, name, allow_var_kwargs=False)
    }
    allowed = (
        explicit_kw
        | _cached_resolve_chat_template_kwargs(chat_template)
        | _get_hf_base_chat_template_params()
    ) - reserved

    return {
        name: value
        for name, value in chat_template_kwargs.items()
        if name in allowed
    }


def safe_apply_chat_template(
    model_config: "ModelConfig",
    tokenizer: HfTokenizer,
    conversation: list[ConversationMessage],
    *,
    tools: list[dict[str, Any]] | None = None,
    chat_template: str | None = None,
    tokenize: bool = True,
    **kwargs,
) -> str | list[int]:
    """Resolve the chat template and render ``conversation`` with it.

    Extra ``kwargs`` are filtered through ``resolve_chat_template_kwargs``
    before being forwarded to ``tokenizer.apply_chat_template``.

    Raises:
        ChatTemplateResolutionError: If no chat template can be resolved.
        ValueError: Wrapping any exception raised while applying the template
            (and raised directly by kwarg validation).
    """
    resolved_template = resolve_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )
    if resolved_template is None:
        raise ChatTemplateResolutionError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    forwarded_kwargs = resolve_chat_template_kwargs(
        tokenizer=tokenizer,
        chat_template=resolved_template,
        chat_template_kwargs=kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=resolved_template,
            tokenize=tokenize,
            **forwarded_kwargs,
        )
    except Exception as exc:
        # Errors from inside `transformers`/Jinja are logged with traceback
        # and surfaced to callers uniformly as ValueError.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(exc)) from exc


490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
def rebuild_mm_uuids_from_mm_data(
    mm_uuids: "MultiModalUUIDDict",
    mm_data: "MultiModalDataDict",
) -> "MultiModalUUIDDict":
    """Replace the ``vision_chunk`` UUIDs with those carried by each chunk.

    After videos are split into chunks, each chunk dict may carry its own
    ``uuid``. The non-``None`` ones are collected in order and stored under
    ``"vision_chunk"`` in a shallow copy of ``mm_uuids``.

    Args:
        mm_uuids: Original UUIDs dictionary (never mutated).
        mm_data: Processed multimodal data with vision_chunk items.

    Returns:
        ``mm_uuids`` itself when there are no vision chunks or none carries a
        UUID; otherwise a new dict with the chunk UUIDs.
    """
    chunks = mm_data.get("vision_chunk")
    if chunks is None:
        return mm_uuids

    assert all(isinstance(chunk, dict) for chunk in chunks), (
        "Expected all vision_chunk items to be dicts"
    )

    chunk_uuids = []
    for chunk in cast(list[dict[str, Any]], chunks):
        chunk_uuid = chunk.get("uuid")
        if chunk_uuid is not None:
            chunk_uuids.append(chunk_uuid)

    if not chunk_uuids:
        return mm_uuids

    return {**mm_uuids, "vision_chunk": chunk_uuids}


def build_video_prompts_from_mm_data(
    mm_data: "MultiModalDataDict",
) -> list[str]:
    """Concatenate the per-chunk prompts of each video.

    Only ``vision_chunk`` items whose ``type`` is ``"video_chunk"`` are used.
    They are grouped by ``video_idx`` (default ``0``) and concatenated in the
    order they appear. The groups are returned sorted by ``video_idx``.

    Args:
        mm_data: Processed multimodal data with vision_chunk items.

    Returns:
        List of video prompts, one per video; empty when there are no video
        chunks.
    """
    chunks = mm_data.get("vision_chunk")
    if chunks is None:
        return []

    prompts_by_video: dict[int, list[str]] = {}
    for chunk in chunks:
        # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
        assert isinstance(chunk, dict)
        if chunk.get("type") != "video_chunk":
            continue
        prompts_by_video.setdefault(chunk.get("video_idx", 0), []).append(
            chunk.get("prompt", "")
        )

    return ["".join(prompts_by_video[idx]) for idx in sorted(prompts_by_video)]


def replace_vision_chunk_video_placeholder(
    prompt_raw: str | list[int],
    mm_data: "MultiModalDataDict",
    video_placeholder: str | None,
) -> str | list[int]:
    """Replace each video placeholder in the prompt with its chunk prompts.

    The N-th occurrence of ``video_placeholder`` in ``prompt_raw`` is replaced
    by the N-th entry of ``build_video_prompts_from_mm_data(mm_data)``.

    ``prompt_raw`` is returned unchanged in three cases:

    - it is already a list of token ids;
    - ``video_placeholder`` is empty/``None``;
    - the number of placeholders differs from the number of videos (a
      warning is logged).
    """
    if not video_placeholder or not isinstance(prompt_raw, str):
        return prompt_raw

    video_prompts = build_video_prompts_from_mm_data(mm_data)

    # Splitting on the placeholder gives one more text segment than there
    # are placeholders; segments and video prompts are then interleaved.
    text_parts = prompt_raw.split(video_placeholder)
    num_placeholders = len(text_parts) - 1
    if num_placeholders != len(video_prompts):
        logger.warning(
            "Number of video placeholders (%d) does not match "
            "number of videos (%d) in the request.",
            num_placeholders,
            len(video_prompts),
        )
        return prompt_raw

    interleaved = "".join(
        itertools.chain.from_iterable(zip(text_parts, video_prompts))
    )
    return interleaved + text_parts[-1]


588
class HfRenderer(BaseRenderer[HfTokenizer]):
    """Renders chat messages into prompts using a Hugging Face tokenizer."""

    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "HfRenderer":
        """Create a renderer from ``config``.

        The tokenizer is loaded via ``cached_get_tokenizer`` with
        ``tokenizer_kwargs`` unless ``model_config.skip_tokenizer_init`` is
        set, in which case the renderer is built without a tokenizer.
        """
        model_config = config.model_config
        if model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cast(
                HfTokenizer,
                cached_get_tokenizer(
                    tokenizer_cls=CachedHfTokenizer,  # type: ignore[type-abstract]
                    **tokenizer_kwargs,
                ),
            )

        return cls(config, tokenizer)

    def __init__(
        self,
        config: VllmConfig,
        tokenizer: HfTokenizer | None,
    ) -> None:
        super().__init__(config, tokenizer)

        # Read from the HF config; defaults to False when the model does not
        # declare it.
        self.use_unified_vision_chunk = getattr(
            config.model_config.hf_config, "use_unified_vision_chunk", False
        )

    def _get_chat_content_format(
        self,
        tokenizer: HfTokenizer,
        params: ChatParams,
    ) -> ChatTemplateContentFormat:
        """Resolve the message-content format expected by the chat template."""
        return resolve_chat_template_content_format(
            chat_template=params.chat_template,
            tools=params.chat_template_kwargs.get("tools"),
            given_format=params.chat_template_content_format,
            tokenizer=tokenizer,
            model_config=self.model_config,
        )

    def _finalize_chat_prompt(
        self,
        prompt_raw: str | list[int],
        mm_data: "MultiModalDataDict | None",
        mm_uuids: "MultiModalUUIDDict | None",
    ) -> DictPrompt:
        """Build the final prompt from the rendered template output.

        This is shared by the sync and async paths so they post-process
        identically. When unified vision chunks are enabled, it rebuilds the
        chunk UUIDs and expands the video placeholders. It then attaches any
        multimodal data/UUIDs to the prompt.
        """
        # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
        # model which uses unified vision chunks for both images and videos.
        if (
            self.use_unified_vision_chunk
            and mm_uuids is not None
            and mm_data is not None
        ):
            mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)

            # get video placeholder, replace it with runtime video-chunk prompts
            video_placeholder = getattr(
                self.model_config.hf_config, "video_placeholder", None
            )
            prompt_raw = replace_vision_chunk_video_placeholder(
                prompt_raw,
                mm_data,
                video_placeholder,
            )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Parse ``messages`` and render them with the chat template.

        Returns:
            The parsed conversation and the prompt built from it.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            model_config,
            content_format=self._get_chat_content_format(tokenizer, params),
            media_io_kwargs=params.media_io_kwargs,
        )

        prompt_raw = safe_apply_chat_template(
            model_config,
            tokenizer,
            conversation,
            **params.get_apply_chat_template_kwargs(),
        )

        return conversation, self._finalize_chat_prompt(prompt_raw, mm_data, mm_uuids)

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Async counterpart of ``render_messages``.

        Messages are parsed with ``parse_chat_messages_async``; the resulting
        prompt is post-processed exactly as in the sync path, including the
        vision-chunk UUID rebuild.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            model_config,
            content_format=self._get_chat_content_format(tokenizer, params),
            media_io_kwargs=params.media_io_kwargs,
        )

        prompt_raw = safe_apply_chat_template(
            model_config,
            tokenizer,
            conversation,
            **params.get_apply_chat_template_kwargs(),
        )

        return conversation, self._finalize_chat_prompt(prompt_raw, mm_data, mm_uuids)