# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""vLLM wrapper for Qwen3-TTS model."""

import base64
import io
import urllib.request
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, ClassVar, Literal, Optional
from urllib.parse import urlparse

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoProcessor

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsV0Only
from .qwen3_tts_utils.output_templates import OmniOutput
from .qwen3_tts_utils.configuration_qwen3_tts import Qwen3TTSConfig
from .qwen3_tts_utils.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
from .qwen3_tts_utils.processing_qwen3_tts import Qwen3TTSProcessor
from .utils import maybe_prefix

logger = init_logger(__name__)

AudioLike = (
    str  # wav path, URL, base64
    | np.ndarray  # waveform (requires sr)
    | tuple[np.ndarray, int]  # (waveform, sr)
)

MaybeList = Any | list[Any]


@dataclass
class VoiceClonePromptItem:
    """
    Container for one sample's voice-clone prompt information that can be fed to the model.

    Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`.
    """

    ref_code: torch.Tensor | None  # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz
    ref_spk_embedding: torch.Tensor  # (D,)
    x_vector_only_mode: bool
    icl_mode: bool
    ref_text: str | None = None


class Qwen3TTSModelForGeneration(nn.Module, SupportsV0Only):
    """
    vLLM wrapper for Qwen3-TTS model.

    This is a non-autoregressive TTS model that generates audio from text.
    It does not produce token logits like standard LLMs, instead it outputs
    audio waveforms directly.

    Note: This model only supports V0 vLLM engine due to its non-standard
    output format (audio instead of logits).
    """

    # Mark as V0 only - TTS models have non-standard output
    supports_v0_only: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config
        model_path = config.model

        # Get dtype from config, default to bfloat16
        torch_dtype = config.dtype
        if torch_dtype is None:
            torch_dtype = torch.bfloat16

        # Check if flash-attn is installed
        try:
            import flash_attn  # noqa: F401
            attn_kwargs = {"attn_implementation": "flash_attention_2"}
        except ImportError:
            logger.warning(
                "Flash-Attn is not installed. "
                "Using default PyTorch attention implementation."
            )
            attn_kwargs = {}

        # Determine device_map from vLLM platform
        from vllm.platforms import current_platform
        if current_platform.is_rocm() or current_platform.is_cuda():
            device_map = current_platform.device_type  # "cuda" for both
        else:
            device_map = None

        # Load the underlying TTS model
        load_kwargs = {"torch_dtype": torch_dtype, **attn_kwargs}
        if device_map is not None:
            load_kwargs["device_map"] = device_map
        self.tts_wrapper = Qwen3TTSModel.from_pretrained(
            model_path,
            **load_kwargs,
        )

        # Infer task type from model path
        self.task_type = self._infer_task_type(model_path)

        # Mark that this model produces multimodal outputs (audio)
        self.have_multimodal_outputs = True

        # Store config for later use
        self.vllm_config = vllm_config
        self.config = config

    def _infer_task_type(self, model_path: str) -> str:
        """Infer task type from model path."""
        path_lower = model_path.lower()
        if "customvoice" in path_lower or "custom_voice" in path_lower:
            return "CustomVoice"
        elif "voicedesign" in path_lower or "voice_design" in path_lower:
            return "VoiceDesign"
        elif "base" in path_lower:
            return "Base"
        # Default: try to extract from path suffix
        suffix = model_path.rstrip("/").split("-")[-1]
        if suffix in ("CustomVoice", "VoiceDesign", "Base"):
            return suffix
        return "CustomVoice"  # Default fallback

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        positions: torch.Tensor | None = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> OmniOutput:
        """
        Forward pass for TTS generation model.

        Unlike standard LLMs, this model generates audio directly from text
        rather than producing token logits for autoregressive generation.

        Args:
            input_ids: Input token IDs (used for text input)
            positions: Position IDs (required by vLLM runner but not used)
            intermediate_tensors: For pipeline parallelism (not supported)
            inputs_embeds: Input embeddings (not used for TTS)
            **kwargs: Additional arguments including:
                - runtime_additional_information: Dict containing TTS params
                  (text, task_type, speaker, language, instruct, etc.)

        Returns:
            OmniOutput: Contains multimodal outputs with audio tensors
        """
        # Extract TTS parameters from runtime_additional_information
        runtime_info_list = kwargs.get("runtime_additional_information", [{}])
        if not isinstance(runtime_info_list, list):
            runtime_info_list = [runtime_info_list]
        if len(runtime_info_list) == 0:
            runtime_info_list = [{}]

        texts: list[str] = []
        task_types: list[str] = []
        speakers: list[str] = []
        languages: list[str] = []
        instructs: list[str] = []
        extra_kwargs: dict[str, list[Any]] = {}

        for runtime_info in runtime_info_list:
            info = dict(runtime_info)
            texts.append(self._extract_param(info, "text", ""))
            task_types.append(
                self._extract_param(info, "task_type", self.task_type))
            speakers.append(self._extract_param(info, "speaker", ""))
            languages.append(self._extract_param(info, "language", "Auto"))
            instructs.append(self._extract_param(info, "instruct", ""))

            for key, value in info.items():
                if isinstance(value, list) and len(value) > 0:
                    value = value[0]
                extra_kwargs.setdefault(key, []).append(value)

        if len(set(task_types)) != 1:
            raise ValueError(
                "Batching multiple task_type values is not supported.")
        task_type = task_types[0]
        text = texts[0] if len(texts) == 1 else texts
        speaker = speakers[0] if len(speakers) == 1 else speakers
        language = languages[0] if len(languages) == 1 else languages
        instruct = instructs[0] if len(instructs) == 1 else instructs
        extra_kwargs = {
            key: (values[0] if len(values) == 1 else values)
            for key, values in extra_kwargs.items()
        }

        # Call the appropriate generation method based on task_type
        if task_type == "CustomVoice":
            result = self.tts_wrapper.generate_custom_voice(
                text,
                speaker=speaker,
                language=language,
                instruct=instruct,
                **extra_kwargs,
            )
        elif task_type == "VoiceDesign":
            result = self.tts_wrapper.generate_voice_design(
                text,
                instruct=instruct,
                language=language,
                **extra_kwargs,
            )
        elif task_type == "Base":
            result = self.tts_wrapper.generate_voice_clone(
                text,
                language=language,
                **extra_kwargs,
            )
        else:
            raise ValueError(
                f"Invalid task type: {task_type}. "
                f"Expected one of: CustomVoice, VoiceDesign, Base"
            )

        return self._make_omni_output(result)

    def _extract_param(
        self,
        info: dict,
        key: str,
        default: Any,
    ) -> Any:
        """Extract parameter from runtime info dict."""
        value = info.pop(key, [default])
        if isinstance(value, list) and len(value) > 0:
            return value[0]
        return value if value is not None else default

    def _make_omni_output(
        self,
        model_outputs: torch.Tensor | OmniOutput | tuple,
    ) -> OmniOutput:
        """
        Convert model outputs to OmniOutput format.

        Args:
            model_outputs: Can be:
                - OmniOutput: returned as-is
                - tuple of (audio_tensors, sample_rate)
                - torch.Tensor: wrapped directly

        Returns:
            OmniOutput with audio data in multimodal_outputs
        """
        if isinstance(model_outputs, OmniOutput):
            return model_outputs

        # Handle tuple format: (audio_tensors, sample_rate)
        if isinstance(model_outputs, tuple) and len(model_outputs) == 2:
            audio_tensors, sr = model_outputs
            if isinstance(audio_tensors, list) and len(audio_tensors) > 0:
                audio_list = []
                for audio_tensor in audio_tensors:
                    if isinstance(audio_tensor, np.ndarray):
                        audio_tensor = torch.from_numpy(audio_tensor).float()
                    elif not isinstance(audio_tensor, torch.Tensor):
                        audio_tensor = torch.tensor(audio_tensor,
                                                    dtype=torch.float32)
                    audio_list.append(audio_tensor)
                return OmniOutput(
                    text_hidden_states=None,
                    multimodal_outputs={
                        "audio": audio_list,
                        "sample_rate": sr,
                    },
                )

            if isinstance(audio_tensors, np.ndarray) and audio_tensors.ndim >= 2:
                audio_list = [
                    torch.from_numpy(audio_tensors[i]).float()
                    for i in range(audio_tensors.shape[0])
                ]
                return OmniOutput(
                    text_hidden_states=None,
                    multimodal_outputs={
                        "audio": audio_list,
                        "sample_rate": sr,
                    },
                )

        # Handle raw tensor
        if isinstance(model_outputs, torch.Tensor):
            return OmniOutput(
                text_hidden_states=None,
                multimodal_outputs={"audio": model_outputs},
            )

        raise ValueError(f"Unsupported model_outputs type: {type(model_outputs)}")

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        """
        Create empty intermediate tensors for pipeline parallelism.

        Note: TTS models do not support pipeline parallelism due to their
        non-standard generation process. This method returns an empty dict
        to satisfy the vLLM interface.
        """
        return IntermediateTensors({})

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get input embeddings from token IDs.

        For TTS models, this returns a dummy tensor since the actual
        embedding is handled internally by the TTS model.
        """
        # TTS models handle embeddings internally
        # Return dummy tensor to satisfy interface
        hidden_size = getattr(
            self.tts_wrapper.model.config,
            "hidden_size",
            1024,
        )
        return torch.zeros(
            (input_ids.shape[0], input_ids.shape[1], hidden_size),
            dtype=self.config.dtype or torch.bfloat16,
            device=input_ids.device,
        )

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """
        Load weights into the model.

        For Qwen3-TTS, weights are loaded via HuggingFace's from_pretrained()
        during __init__. This method handles any additional weight loading
        that vLLM may perform.

        Args:
            weights: Iterable of (name, tensor) pairs

        Returns:
            Set of parameter names that were loaded
        """
        # The model is already loaded via from_pretrained in __init__
        # Here we just track which weights were provided
        loaded_params: set[str] = set()

        # Get the actual model's parameter dict for matching
        model_params = dict(self.tts_wrapper.model.named_parameters())

        for name, loaded_weight in weights:
            # Try to match and load weights if they exist in the model
            if name in model_params:
                param = model_params[name]
                if param.shape == loaded_weight.shape:
                    param.data.copy_(loaded_weight)
                    loaded_params.add(name)
                else:
                    logger.warning(
                        f"Shape mismatch for {name}: "
                        f"expected {param.shape}, got {loaded_weight.shape}"
                    )
            else:
                # Weight not in model, just track it
                loaded_params.add(name)

        return loaded_params

    def compute_logits(
        self,
        hidden_states: torch.Tensor | OmniOutput,
        sampling_metadata: Optional[SamplingMetadata] = None,
    ) -> Optional[torch.Tensor]:
        """
        Compute logits from hidden states.

        For TTS models, this returns None since they don't produce token
        logits. Instead, they generate audio directly.
        """
        return None

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Sample next tokens from logits.

        For TTS models, this returns None since they don't use
        autoregressive token sampling.
        """
        return None


class Qwen3TTSModel:
    """
    A HuggingFace-style wrapper for Qwen3 TTS models (CustomVoice/VoiceDesign/Base) that provides:
      - from_pretrained() initialization via AutoModel/AutoProcessor
      - generation APIs for:
          * CustomVoice: generate_custom_voice()
          * VoiceDesign: generate_voice_design()
          * Base: generate_voice_clone() + create_voice_clone_prompt()
      - consistent output: (wavs: List[np.ndarray], sample_rate: int)

    Notes:
      - This wrapper expects the underlying model class to be `Qwen3TTSForConditionalGeneration`
      - Language / speaker validation is done via model methods:
          model.get_supported_languages(), model.get_supported_speakers()
    """

    def __init__(
        self, model: Qwen3TTSForConditionalGeneration, processor, generate_defaults: dict[str, Any] | None = None
    ):
        self.model = model
        self.processor = processor
        self.generate_defaults = generate_defaults or {}

        self.device = getattr(model, "device", None)
        if self.device is None:
            try:
                self.device = next(model.parameters()).device
            except StopIteration:
                self.device = torch.device("cpu")

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs: Any,
    ) -> "Qwen3TTSModel":
        """
        Load a Qwen3 TTS model and its processor in HuggingFace `from_pretrained` style.

        This method:
          1) Loads config via AutoConfig (so your side can register model_type -> config/model).
          2) Loads the model via AutoModel.from_pretrained(...), forwarding `kwargs` unchanged.
          3) Loads the processor via AutoProcessor.from_pretrained(model_path).
          4) Loads optional `generate_config.json` from the model directory/repo snapshot if present.

        Args:
            pretrained_model_name_or_path (str):
                HuggingFace repo id or local directory of the model.
            **kwargs:
                Forwarded as-is into `AutoModel.from_pretrained(...)`.
                Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="flash_attention_2".

        Returns:
            Qwen3TTSModel:
                Wrapper instance containing `model`, `processor`, and generation defaults.
        """
        AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
        AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)
        AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor)

        model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
        if not isinstance(model, Qwen3TTSForConditionalGeneration):
            raise TypeError(f"AutoModel returned {type(model)}, expected Qwen3TTSForConditionalGeneration. ")

        processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path,
            fix_mistral_regex=True,
        )

        generate_defaults = model.generate_config
        return cls(model=model, processor=processor, generate_defaults=generate_defaults)

    def _supported_languages_set(self) -> set | None:
        langs = getattr(self.model, "get_supported_languages", None)
        if callable(langs):
            v = langs()
            if v is None:
                return None
            return set([str(x).lower() for x in v])
        return None

    def _supported_speakers_set(self) -> set | None:
        spks = getattr(self.model, "get_supported_speakers", None)
        if callable(spks):
            v = spks()
            if v is None:
                return None
            return set([str(x).lower() for x in v])
        return None

    def _validate_languages(self, languages: list[str]) -> None:
        """
        Validate that requested languages are supported by the model.

        Args:
            languages (List[str]): Language names for each sample.

        Raises:
            ValueError: If any language is not supported.
        """
        supported = self._supported_languages_set()
        if supported is None:
            return

        bad = []
        for lang in languages:
            if lang is None:
                bad.append(lang)
                continue
            if str(lang).lower() not in supported:
                bad.append(lang)
        if bad:
            raise ValueError(f"Unsupported languages: {bad}. Supported: {sorted(supported)}")

    def _validate_speakers(self, speakers: list[str | None]) -> None:
        """
        Validate that requested speakers are supported by the Instruct model.

        Args:
            speakers (List[Optional[str]]): Speaker names for each sample.

        Raises:
            ValueError: If any speaker is not supported.
        """
        supported = self._supported_speakers_set()
        if supported is None:
            return

        bad = []
        for spk in speakers:
            if spk is None or spk == "":
                continue
            if str(spk).lower() not in supported:
                bad.append(spk)
        if bad:
            raise ValueError(f"Unsupported speakers: {bad}. Supported: {sorted(supported)}")

    def _is_probably_base64(self, s: str) -> bool:
        if s.startswith("data:audio"):
            return True
        if ("/" not in s and "\\" not in s) and len(s) > 256:
            return True
        return False

    def _is_url(self, s: str) -> bool:
        try:
            u = urlparse(s)
            return u.scheme in ("http", "https") and bool(u.netloc)
        except Exception:
            return False

    def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
        if "," in b64 and b64.strip().startswith("data:"):
            b64 = b64.split(",", 1)[1]
        return base64.b64decode(b64)

    def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]:
        if self._is_url(x):
            with urllib.request.urlopen(x) as resp:
                audio_bytes = resp.read()
            with io.BytesIO(audio_bytes) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)
        elif self._is_probably_base64(x):
            wav_bytes = self._decode_base64_to_wav_bytes(x)
            with io.BytesIO(wav_bytes) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)
        else:
            audio, sr = librosa.load(x, sr=None, mono=True)

        if audio.ndim > 1:
            audio = np.mean(audio, axis=-1)

        return audio.astype(np.float32), int(sr)

    def _normalize_audio_inputs(self, audios: AudioLike | list[AudioLike]) -> list[tuple[np.ndarray, int]]:
        """
        Normalize audio inputs into a list of (waveform, sr).

        Supported forms:
          - str: wav path / URL / base64 audio string
          - (np.ndarray, sr): waveform + sampling rate
          - list of the above

        Args:
            audios:
                Audio input(s).

        Returns:
            List[Tuple[np.ndarray, int]]:
                List of (float32 waveform, original sr).

        Raises:
            ValueError: If a numpy waveform is provided without sr.
        """
        if isinstance(audios, list):
            items = audios
        else:
            items = [audios]

        out: list[tuple[np.ndarray, int]] = []
        for a in items:
            if isinstance(a, str):
                out.append(self._load_audio_to_np(a))
            elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
                out.append((a[0].astype(np.float32), int(a[1])))
            elif isinstance(a, np.ndarray):
                raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
            else:
                raise TypeError(f"Unsupported audio input type: {type(a)}")
        for i, a in enumerate(out):
            if a[0].ndim > 1:
                a[0] = np.mean(a[0], axis=-1).astype(np.float32)
                out[i] = (a[0], a[1])
        return out

    def _ensure_list(self, x: MaybeList) -> list[Any]:
        return x if isinstance(x, list) else [x]

    def _build_assistant_text(self, text: str) -> str:
        return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"

    def _build_ref_text(self, text: str) -> str:
        return f"<|im_start|>assistant\n{text}<|im_end|>\n"

    def _build_instruct_text(self, instruct: str) -> str:
        return f"<|im_start|>user\n{instruct}<|im_end|>\n"

    def _tokenize_texts(self, texts: list[str]) -> list[torch.Tensor]:
        input_ids = []
        for text in texts:
            input = self.processor(text=text, return_tensors="pt", padding=True)
            input_id = input["input_ids"].to(self.device)
            input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id
            input_ids.append(input_id)
        return input_ids

    def _merge_generate_kwargs(
        self,
        non_streaming_mode: bool | None = None,
        do_sample: bool | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        temperature: float | None = None,
        repetition_penalty: float | None = None,
        subtalker_dosample: bool | None = None,
        subtalker_top_k: int | None = None,
        subtalker_top_p: float | None = None,
        subtalker_temperature: float | None = None,
        max_new_tokens: int | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Merge user-provided generation arguments with defaults from `generate_config.json`.

        Rule:
          - If the user explicitly passes a value (not None), use it.
          - Otherwise, use the value from generate_config.json if present.
          - Otherwise, fall back to the hard defaults.

        Args:
            non_streaming_mode, do_sample, top_k, top_p, temperature, repetition_penalty,
            subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
                Common generation parameters.
            **kwargs:
                Other arguments forwarded to model.generate().

        Returns:
            Dict[str, Any]: Final kwargs to pass into model.generate().
        """
        hard_defaults = dict(
            non_streaming_mode=False,
            do_sample=True,
            top_k=50,
            top_p=1.0,
            temperature=0.9,
            repetition_penalty=1.05,
            subtalker_dosample=True,
            subtalker_top_k=50,
            subtalker_top_p=1.0,
            subtalker_temperature=0.9,
            max_new_tokens=2048,
        )

        def pick(name: str, user_val: Any) -> Any:
            if user_val is not None:
                return user_val
            if name in self.generate_defaults:
                return self.generate_defaults[name]
            return hard_defaults[name]

        merged = dict(kwargs)
        merged.update(
            non_streaming_mode=pick("non_streaming_mode", non_streaming_mode),
            do_sample=pick("do_sample", do_sample),
            top_k=pick("top_k", top_k),
            top_p=pick("top_p", top_p),
            temperature=pick("temperature", temperature),
            repetition_penalty=pick("repetition_penalty", repetition_penalty),
            subtalker_dosample=pick("subtalker_dosample", subtalker_dosample),
            subtalker_top_k=pick("subtalker_top_k", subtalker_top_k),
            subtalker_top_p=pick("subtalker_top_p", subtalker_top_p),
            subtalker_temperature=pick("subtalker_temperature", subtalker_temperature),
            max_new_tokens=pick("max_new_tokens", max_new_tokens),
        )
        return merged

    # voice clone model
    @torch.inference_mode()
    def create_voice_clone_prompt(
        self,
        ref_audio: AudioLike | list[AudioLike],
        ref_text: str | list[str | None] | None = None,
        x_vector_only_mode: bool | list[bool] = False,
    ) -> list[VoiceClonePromptItem]:
        """
        Build voice-clone prompt items from reference audio (and optionally reference text) using Base model.

        Modes:
          - x_vector_only_mode=True:
              Only speaker embedding is used to clone voice; ref_text/ref_code are ignored.
              This is mutually exclusive with ICL.
          - x_vector_only_mode=False:
              ICL mode is enabled automatically (icl_mode=True). In this case ref_text is required,
              because the model continues/conditions on the reference text + reference speech codes.

        Batch behavior:
          - ref_audio can be a single item or a list.
          - ref_text and x_vector_only_mode can be scalars or lists.
          - If any of them are lists with length > 1, lengths must match.

        Audio input:
          - str: local wav path / URL / base64
          - (np.ndarray, sr): waveform + sampling rate

        Args:
            ref_audio:
                Reference audio(s) used to extract:
                  - ref_code via `model.speech_tokenizer.encode(...)`
                  - ref_spk_embedding via `model.extract_speaker_embedding(...)` (resampled to 24k)
            ref_text:
                Reference transcript(s). Required when x_vector_only_mode=False (ICL mode).
            x_vector_only_mode:
                Whether to use speaker embedding only. If False, ICL mode will be used.

        Returns:
            List[VoiceClonePromptItem]:
                List of prompt items that can be converted into `voice_clone_prompt` dict.

        Raises:
            ValueError:
                - If x_vector_only_mode=False but ref_text is missing.
                - If batch lengths mismatch.
        """
        if self.model.tts_model_type != "base":
            raise ValueError(
                f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
                f"tts_model_size: {self.model.tts_model_size}\n"
                f"tts_model_type: {self.model.tts_model_type}\n"
                "does not support create_voice_clone_prompt, Please check Model Card or Readme for more details."
            )

        ref_audio_list = self._ensure_list(ref_audio)
        ref_text_list = (
            self._ensure_list(ref_text) if isinstance(ref_text, list) else ([ref_text] * len(ref_audio_list))
        )
        xvec_list = (
            self._ensure_list(x_vector_only_mode)
            if isinstance(x_vector_only_mode, list)
            else ([x_vector_only_mode] * len(ref_audio_list))
        )

        if len(ref_text_list) != len(ref_audio_list) or len(xvec_list) != len(ref_audio_list):
            raise ValueError(
                f"Batch size mismatch: ref_audio={len(ref_audio_list)}, "
                f"ref_text={len(ref_text_list)}, "
                f"x_vector_only_mode={len(xvec_list)}"
            )

        normalized = self._normalize_audio_inputs(ref_audio_list)

        ref_wavs_for_code: list[np.ndarray] = []
        ref_sr_for_code: list[int] = []
        for wav, sr in normalized:
            ref_wavs_for_code.append(wav)
            ref_sr_for_code.append(sr)

        if len(set(ref_sr_for_code)) == 1:
            enc = self.model.speech_tokenizer.encode(ref_wavs_for_code, sr=ref_sr_for_code[0])
            ref_codes = enc.audio_codes
        else:
            ref_codes = []
            for wav, sr in normalized:
                ref_codes.append(self.model.speech_tokenizer.encode(wav, sr=sr).audio_codes[0])

        items: list[VoiceClonePromptItem] = []
        for i, ((wav, sr), code, rtext, xvec_only) in enumerate(zip(normalized, ref_codes, ref_text_list, xvec_list)):
            if not xvec_only:
                if rtext is None or rtext == "":
                    rtext = "For profile run"
                    logger.warning(
                        f"ref_text is required when x_vector_only_mode=False (ICL mode). "
                        f"Bad index={i}. Please check if it is profile run or "
                        f"you missed to provide ref_text."
                    )
                    # raise ValueError(f"ref_text is required when x_vector_only_mode=False (ICL mode). Bad index={i}")

            wav_resample = wav
            if sr != self.model.speaker_encoder_sample_rate:
                wav_resample = librosa.resample(
                    y=wav_resample.astype(np.float32), orig_sr=int(sr), target_sr=self.model.speaker_encoder_sample_rate
                )

            spk_emb = self.model.extract_speaker_embedding(
                audio=wav_resample, sr=self.model.speaker_encoder_sample_rate
            )

            items.append(
                VoiceClonePromptItem(
                    ref_code=None if xvec_only else code,
                    ref_spk_embedding=spk_emb,
                    x_vector_only_mode=bool(xvec_only),
                    icl_mode=bool(not xvec_only),
                    ref_text=rtext,
                )
            )
        return items

    def _prompt_items_to_voice_clone_prompt(self, items: list[VoiceClonePromptItem]) -> dict[str, Any]:
        return dict(
            ref_code=[it.ref_code for it in items],
            ref_spk_embedding=[it.ref_spk_embedding for it in items],
            x_vector_only_mode=[it.x_vector_only_mode for it in items],
            icl_mode=[it.icl_mode for it in items],
        )

    # voice clone model
    @torch.no_grad()
    def generate_voice_clone(
        self,
        text: str | list[str],
        language: str | list[str] = None,
        ref_audio: AudioLike | list[AudioLike] | None = None,
        ref_text: str | list[str | None] | None = None,
        x_vector_only_mode: bool | list[bool] = False,
        voice_clone_prompt: dict[str, Any] | list[VoiceClonePromptItem] | None = None,
        **kwargs: Any,
    ) -> tuple[list[np.ndarray], int]:
        """
        Voice clone speech using the Base model.

        You can provide either:
          - (ref_audio, ref_text, x_vector_only_mode) and let this method build the prompt, OR
          - `VoiceClonePromptItem` returned by `create_voice_clone_prompt`, OR
          - a list of `VoiceClonePromptItem` returned by `create_voice_clone_prompt`.

        `ref_audio` Supported forms:
        - str: wav path / URL / base64 audio string
        - (np.ndarray, sr): waveform + sampling rate
        - list of the above

        Input flexibility:
          - text/language can be scalar or list.
          - prompt can be single or batch.
          - If batch mode (len(text)>1), lengths must match.

        Args:
            text:
                Text(s) to synthesize.
            language:
                Language(s) for each sample.
            ref_audio:
                Reference audio(s) for prompt building. Required if voice_clone_prompt is not provided.
            ref_text:
                Reference text(s) used for ICL mode (required when x_vector_only_mode=False).
            x_vector_only_mode:
                If True, only speaker embedding is used (ignores ref_text/ref_code).
                If False, ICL mode is used automatically.
            voice_clone_prompt:
                list[VoiceClonePromptItem] from `create_voice_clone_prompt`.
            **kwargs:
                Additional generation options. Common keys include `non_streaming_mode`, `do_sample`, `top_k`, `top_p`,
                `temperature`, `repetition_penalty`, `subtalker_dosample`, `subtalker_top_k`, `subtalker_top_p`,
                `subtalker_temperature`, and `max_new_tokens`. Any other keyword arguments supported by HuggingFace
                Transformers `generate()` can also be passed and will be forwarded to
                `Qwen3TTSForConditionalGeneration.generate(...)`.

        Returns:
            Tuple[List[np.ndarray], int]:
                (wavs, sample_rate)

        Raises:
            ValueError:
                If batch sizes mismatch or required prompt inputs are missing.
        """
        if self.model.tts_model_type != "base":
            raise ValueError(
                f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
                f"tts_model_size: {self.model.tts_model_size}\n"
                f"tts_model_type: {self.model.tts_model_type}\n"
                "does not support generate_voice_clone, Please check Model Card or Readme for more details."
            )

        texts = self._ensure_list(text)
        languages = (
            self._ensure_list(language)
            if isinstance(language, list)
            else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
        )
        if len(languages) == 1 and len(texts) > 1:
            languages = languages * len(texts)
        if len(texts) != len(languages):
            raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}")

        self._validate_languages(languages)

        if voice_clone_prompt is None:
            if ref_audio is None:
                # For profile run
                sample_rate = int(self.model.speaker_encoder_sample_rate)
                # Use a 1-second silent clip to satisfy padding requirements.
                ref_audio = (np.zeros(sample_rate, dtype=np.float32), sample_rate)
                logger.warning(
                    "ref_audio is not provided. Using a 1-second silent clip "
                    "to satisfy padding requirements. Please check if it is "
                    "profile run or you missed to provide ref_audio."
                )
            prompt_items = self.create_voice_clone_prompt(
                ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode
            )
            if len(prompt_items) == 1 and len(texts) > 1:
                prompt_items = prompt_items * len(texts)
            if len(prompt_items) != len(texts):
                raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
            voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
            ref_texts_for_ids = [it.ref_text for it in prompt_items]
        else:
            if isinstance(voice_clone_prompt, list):
                prompt_items = voice_clone_prompt
                if len(prompt_items) == 1 and len(texts) > 1:
                    prompt_items = prompt_items * len(texts)
                if len(prompt_items) != len(texts):
                    raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
                voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
                ref_texts_for_ids = [it.ref_text for it in prompt_items]
            else:
                voice_clone_prompt_dict = voice_clone_prompt
                ref_texts_for_ids = None

        input_texts = [self._build_assistant_text(t) for t in texts]
        input_ids = self._tokenize_texts(input_texts)

        ref_ids = None
        if ref_texts_for_ids is not None:
            ref_ids = []
            for i, rt in enumerate(ref_texts_for_ids):
                if rt is None or rt == "":
                    ref_ids.append(None)
                else:
                    ref_tok = self._tokenize_texts([self._build_ref_text(rt)])[0]
                    ref_ids.append(ref_tok)

        gen_kwargs = self._merge_generate_kwargs(**kwargs)

        talker_codes_list, _ = self.model.generate(
            input_ids=input_ids,
            ref_ids=ref_ids,
            voice_clone_prompt=voice_clone_prompt_dict,
            languages=languages,
            **gen_kwargs,
        )

        codes_for_decode = []
        for i, codes in enumerate(talker_codes_list):
            ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
            if ref_code_list is not None and ref_code_list[i] is not None:
                codes_for_decode.append(torch.cat([ref_code_list[i].to(codes.device), codes], dim=0))
            else:
                codes_for_decode.append(codes)

        wavs_all, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in codes_for_decode])

        wavs_out: list[np.ndarray] = []
        for i, wav in enumerate(wavs_all):
            ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
            if ref_code_list is not None and ref_code_list[i] is not None:
                ref_len = int(ref_code_list[i].shape[0])
                total_len = int(codes_for_decode[i].shape[0])
                cut = int(ref_len / max(total_len, 1) * wav.shape[0])
                wavs_out.append(wav[cut:])
            else:
                wavs_out.append(wav)

        return wavs_out, fs

    # voice design model
    @torch.no_grad()
    def generate_voice_design(
        self,
        text: str | list[str],
        instruct: str | list[str],
        language: str | list[str] = None,
        **kwargs: Any,
    ) -> tuple[list[np.ndarray], int]:
        """
        Generate speech with the VoiceDesign model using natural-language style instructions.

        Args:
            text:
                Text(s) to synthesize.
            language:
                Language(s) for each sample.
            instruct:
                Instruction(s) describing desired voice/style. Empty string is allowed (treated as no instruction).
            **kwargs:
                Additional generation options. Common keys include `non_streaming_mode`, `do_sample`, `top_k`, `top_p`,
                `temperature`, `repetition_penalty`, `subtalker_dosample`, `subtalker_top_k`, `subtalker_top_p`,
                `subtalker_temperature`, and `max_new_tokens`. Any other keyword arguments supported by HuggingFace
                Transformers `generate()` can also be passed and will be forwarded to
                `Qwen3TTSForConditionalGeneration.generate(...)`.

        Returns:
            Tuple[List[np.ndarray], int]:
                (wavs, sample_rate)
        """
        if self.model.tts_model_type != "voice_design":
            raise ValueError(
                f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
                f"tts_model_size: {self.model.tts_model_size}\n"
                f"tts_model_type: {self.model.tts_model_type}\n"
                "does not support generate_voice_design, Please check Model Card or Readme for more details."
            )

        texts = self._ensure_list(text)
        languages = (
            self._ensure_list(language)
            if isinstance(language, list)
            else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
        )
        instructs = self._ensure_list(instruct)

        if len(languages) == 1 and len(texts) > 1:
            languages = languages * len(texts)
        if len(instructs) == 1 and len(texts) > 1:
            instructs = instructs * len(texts)

        if not (len(texts) == len(languages) == len(instructs)):
            raise ValueError(
                f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}"
            )

        self._validate_languages(languages)

        input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

        instruct_ids: list[torch.Tensor | None] = []
        for ins in instructs:
            if ins is None or ins == "":
                instruct_ids.append(None)
            else:
                instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0])

        gen_kwargs = self._merge_generate_kwargs(**kwargs)

        talker_codes_list, _ = self.model.generate(
            input_ids=input_ids,
            instruct_ids=instruct_ids,
            languages=languages,
            **gen_kwargs,
        )

        wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
        return wavs, fs

    # custom voice model
    @torch.no_grad()
    def generate_custom_voice(
        self,
        text: str | list[str],
        speaker: str | list[str],
        language: str | list[str] = None,
        instruct: str | list[str] | None = None,
        **kwargs: Any,
    ) -> tuple[list[np.ndarray], int]:
        """
        Generate speech with the CustomVoice model using a predefined speaker id,
        optionally controlled by instruction text.

        Args:
            text:
                Text(s) to synthesize.
            language:
                Language(s) for each sample.
            speaker:
                Speaker name(s). Will be validated against `model.get_supported_speakers()` (case-insensitive).
            instruct:
                Optional instruction(s). If None, treated as empty (no instruction).
            **kwargs:
                Additional generation options. Common keys include `non_streaming_mode`, `do_sample`, `top_k`, `top_p`,
                `temperature`, `repetition_penalty`, `subtalker_dosample`, `subtalker_top_k`, `subtalker_top_p`,
                `subtalker_temperature`, and `max_new_tokens`. Any other keyword arguments supported by HuggingFace
                Transformers `generate()` can also be passed and will be forwarded to
                `Qwen3TTSForConditionalGeneration.generate(...)`.

        Returns:
            Tuple[List[np.ndarray], int]:
                (wavs, sample_rate)

        Raises:
            ValueError:
                If any speaker/language is unsupported or batch sizes mismatch.
        """
        if self.model.tts_model_type != "custom_voice":
            raise ValueError(
                f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
                f"tts_model_size: {self.model.tts_model_size}\n"
                f"tts_model_type: {self.model.tts_model_type}\n"
                "does not support generate_custom_voice, Please check Model Card or Readme for more details."
            )

        texts = self._ensure_list(text)
        languages = (
            self._ensure_list(language)
            if isinstance(language, list)
            else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
        )
        speakers = self._ensure_list(speaker)
        if self.model.tts_model_size in "0b6":  # for 0b6 model, instruct is not supported
            instruct = None
        instructs = (
            self._ensure_list(instruct)
            if isinstance(instruct, list)
            else ([instruct] * len(texts) if instruct is not None else [""] * len(texts))
        )

        if len(languages) == 1 and len(texts) > 1:
            languages = languages * len(texts)
        if len(speakers) == 1 and len(texts) > 1:
            speakers = speakers * len(texts)
        if len(instructs) == 1 and len(texts) > 1:
            instructs = instructs * len(texts)

        if not (len(texts) == len(languages) == len(speakers) == len(instructs)):
            raise ValueError(
                f"Batch size mismatch: text={len(texts)}, "
                f"language={len(languages)}, speaker={len(speakers)}, "
                f"instruct={len(instructs)}"
            )

        self._validate_languages(languages)
        self._validate_speakers(speakers)

        input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

        instruct_ids: list[torch.Tensor | None] = []
        for ins in instructs:
            if ins is None or ins == "":
                instruct_ids.append(None)
            else:
                instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0])

        gen_kwargs = self._merge_generate_kwargs(**kwargs)

        talker_codes_list, _ = self.model.generate(
            input_ids=input_ids,
            instruct_ids=instruct_ids,
            languages=languages,
            speakers=speakers,
            **gen_kwargs,
        )

        wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
        return wavs, fs

    def get_supported_speakers(self) -> list[str] | None:
        """
        List supported speaker names for the current model.

        This is a convenience wrapper around `model.get_supported_speakers()`.
        If the underlying model does not expose speaker constraints (returns None),
        this method also returns None.

        Returns:
            Optional[List[str]]:
                - A sorted list of supported speaker names (lowercased), if available.
                - None if the model does not provide supported speakers.
        """
        supported = self._supported_speakers_set()
        if supported is None:
            return None
        return sorted(supported)

    def get_supported_languages(self) -> list[str] | None:
        """
        List supported language names for the current model.

        This is a convenience wrapper around `model.get_supported_languages()`.
        If the underlying model does not expose language constraints (returns None),
        this method also returns None.

        Returns:
            Optional[List[str]]:
                - A sorted list of supported language names (lowercased), if available.
                - None if the model does not provide supported languages.
        """
        supported = self._supported_languages_set()
        if supported is None:
            return None
        return sorted(supported)


# =============================================================================
# vLLM-Optimized TTS Model (Phase 1 & 2)
# =============================================================================

class Qwen3TTSModelForGenerationV2(nn.Module, SupportsV0Only):
    """
    vLLM-optimized wrapper for Qwen3-TTS model.

    This version uses the vLLM-native Talker model with:
    - PagedAttention for efficient KV cache management
    - Tensor parallelism support
    - Optimized attention backends (FlashAttention, etc.)

    The Speech Tokenizer and other components remain HuggingFace-based
    since they only run once per generation.

    Note: This model only supports V0 vLLM engine due to its non-standard
    output format (audio instead of logits).
    """

    supports_v0_only: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        from .qwen3_tts_talker import Qwen3TTSTalkerForCausalLM

        model_config = vllm_config.model_config
        model_path = model_config.model

        # Get dtype from config
        torch_dtype = model_config.dtype
        if torch_dtype is None:
            torch_dtype = torch.bfloat16

        # Determine device from vLLM platform
        from vllm.platforms import current_platform
        if current_platform.is_rocm() or current_platform.is_cuda():
            device = torch.device(current_platform.device_type)
        else:
            device = torch.device("cpu")

        self.device = device
        self.dtype = torch_dtype

        # Load the HuggingFace config to get talker_config
        hf_config = model_config.hf_config
        if hasattr(hf_config, 'talker_config'):
            talker_config = hf_config.talker_config
        else:
            # Fallback: load from model path
            from transformers import AutoConfig
            AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
            full_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
            talker_config = full_config.talker_config

        # Create a modified vllm_config for the Talker
        # The Talker uses talker_config as its hf_config
        talker_vllm_config = vllm_config.with_hf_config(talker_config)

        # Initialize vLLM-native Talker model
        self.talker = Qwen3TTSTalkerForCausalLM(
            vllm_config=talker_vllm_config,
            prefix=maybe_prefix(prefix, "talker"),
        )

        # Load Speech Tokenizer (HuggingFace-based, runs once per generation)
        from .qwen3_tts_utils.qwen3_tts_tokenizer import Qwen3TTSTokenizer
        try:
            import os
            speech_tokenizer_path = os.path.join(model_path, "speech_tokenizer")
            if os.path.exists(speech_tokenizer_path):
                self.speech_tokenizer = Qwen3TTSTokenizer.from_pretrained(
                    speech_tokenizer_path,
                    torch_dtype=torch_dtype,
                )
            else:
                logger.warning(
                    f"Speech tokenizer not found at {speech_tokenizer_path}. "
                    "Audio decoding will not be available."
                )
                self.speech_tokenizer = None
        except Exception as e:
            logger.warning(f"Failed to load speech tokenizer: {e}")
            self.speech_tokenizer = None

        # Load processor for text tokenization
        try:
            AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor)
            self.processor = AutoProcessor.from_pretrained(
                model_path,
                fix_mistral_regex=True,
            )
        except Exception as e:
            logger.warning(f"Failed to load processor: {e}")
            self.processor = None

        # Infer task type from model path
        self.task_type = self._infer_task_type(model_path)

        # Store configs
        self.vllm_config = vllm_config
        self.model_config = model_config
        self.talker_config = talker_config

        # Mark multimodal outputs
        self.have_multimodal_outputs = True

        logger.info(
            f"Initialized Qwen3TTSModelForGenerationV2 with vLLM-native Talker. "
            f"Task type: {self.task_type}"
        )

    def _infer_task_type(self, model_path: str) -> str:
        """Infer task type from model path."""
        path_lower = model_path.lower()
        if "customvoice" in path_lower or "custom_voice" in path_lower:
            return "CustomVoice"
        elif "voicedesign" in path_lower or "voice_design" in path_lower:
            return "VoiceDesign"
        elif "base" in path_lower:
            return "Base"
        suffix = model_path.rstrip("/").split("-")[-1]
        if suffix in ("CustomVoice", "VoiceDesign", "Base"):
            return suffix
        return "CustomVoice"

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        positions: torch.Tensor | None = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> OmniOutput:
        """
        Forward pass for TTS generation.

        This method handles the full TTS pipeline:
        1. Process text input to embeddings
        2. Generate codec tokens using vLLM-native Talker
        3. Decode codec tokens to audio using Speech Tokenizer

        Args:
            input_ids: Input token IDs
            positions: Position indices
            intermediate_tensors: For pipeline parallelism
            inputs_embeds: Pre-computed embeddings
            **kwargs: Additional TTS parameters

        Returns:
            OmniOutput with audio data
        """
        # Extract TTS parameters
        runtime_info = kwargs.get("runtime_additional_information", [{}])
        if isinstance(runtime_info, list) and len(runtime_info) > 0:
            runtime_info = runtime_info[0]

        text = self._extract_param(runtime_info, "text", "")
        task_type = self._extract_param(runtime_info, "task_type", self.task_type)
        language = self._extract_param(runtime_info, "language", "Auto")
        instruct = self._extract_param(runtime_info, "instruct", "")

        # For now, use the Talker's forward pass directly
        # Full TTS pipeline integration requires more work
        if inputs_embeds is not None:
            hidden_states = self.talker.forward(
                input_ids=None,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        elif input_ids is not None:
            hidden_states = self.talker.forward(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
            )
        else:
            # No input provided, return empty output
            return OmniOutput(
                text_hidden_states=None,
                multimodal_outputs={"audio": None, "sample_rate": 24000},
            )

        # Return hidden states wrapped in OmniOutput
        # Full audio generation requires the complete TTS pipeline
        return OmniOutput(
            text_hidden_states=hidden_states,
            multimodal_outputs={"audio": None, "sample_rate": 24000},
        )

    def _extract_param(self, info: dict, key: str, default: Any) -> Any:
        """Extract parameter from runtime info dict."""
        value = info.pop(key, [default])
        if isinstance(value, list) and len(value) > 0:
            return value[0]
        return value if value is not None else default

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        """Create empty intermediate tensors for pipeline parallelism."""
        return self.talker.make_empty_intermediate_tensors

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Get input embeddings from token IDs."""
        return self.talker.get_input_embeddings(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor | OmniOutput,
        sampling_metadata: Optional[SamplingMetadata] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states."""
        if isinstance(hidden_states, OmniOutput):
            hidden_states = hidden_states.text_hidden_states
        if hidden_states is None:
            return None
        return self.talker.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from logits."""
        # TTS models use custom sampling, return None for now
        return None

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load weights from checkpoint.

        Handles weight mapping for:
        - Talker model weights (with prefix remapping)
        - Fused gate_up_proj from separate gate_proj and up_proj
        """
        loaded_params: set[str] = set()

        # Collect weights for the Talker
        talker_weights = []
        for name, weight in weights:
            # Remap weight names from HuggingFace format
            # HF: talker.model.layers.X.* -> vLLM: model.layers.X.*
            if name.startswith("talker."):
                new_name = name[len("talker."):]
                talker_weights.append((new_name, weight))
            else:
                talker_weights.append((name, weight))

        # Load weights into Talker
        loaded = self.talker.load_weights(iter(talker_weights))
        loaded_params.update(loaded)

        return loaded_params
