Unverified Commit d39b8daf authored by haosdent's avatar haosdent Committed by GitHub
Browse files

[Feature] Add Qwen3-ForcedAligner support via token classification pooling (#35367)


Signed-off-by: default avatarhaosdent <haosdent@gmail.com>
parent fafca38a
......@@ -29,6 +29,12 @@ Offline: [examples/pooling/token_classify/ner_offline.py](../../../examples/pool
Online: [examples/pooling/token_classify/ner_online.py](../../../examples/pooling/token_classify/ner_online.py)
### Forced Alignment
Forced alignment takes audio and reference text as input and produces word-level timestamps.
Offline: [examples/pooling/token_classify/forced_alignment_offline.py](../../../examples/pooling/token_classify/forced_alignment_offline.py)
### Sparse retrieval (lexical matching)
The BAAI/bge-m3 model leverages token classification for sparse retrieval. For more information, see [this page](specific_models.md#baaibge-m3).
......@@ -49,6 +55,19 @@ The BAAI/bge-m3 model leverages token classification for sparse retrieval. For m
If your model is not in the above list, we will try to automatically convert the model using
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
### Multimodal Models
!!! note
For more information about multimodal models inputs, see [this page](../supported_models.md#list-of-multimodal-language-models).
| Architecture | Models | Inputs | Example HF Models | [LoRA](../../features/lora.md) | [PP](../../serving/parallelism_scaling.md) |
| --------------------------------------------- | ------------------- | ----------------- | ------------------------------------------ | ------------------------------ | ------------------------------------------ |
| `Qwen3ASRForcedAlignerForTokenClassification` | Qwen3-ForcedAligner | T + A<sup>+</sup> | `Qwen/Qwen3-ForcedAligner-0.6B` (see note) | | ✅︎ |
!!! note
Forced alignment usage requires `--hf-overrides '{"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]}'`.
Please refer to [examples/pooling/token_classify/forced_alignment_offline.py](../../../examples/pooling/token_classify/forced_alignment_offline.py).
### As Reward Models
Using token classification models as reward models. For details on reward models, see [Reward Models](reward.md).
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from Qwen3-ForcedAligner inference:
# https://github.com/QwenLM/Qwen3-ASR
"""
Offline forced alignment example using Qwen3-ForcedAligner-0.6B.
Forced alignment takes audio and reference text as input and produces
word-level timestamps. The model predicts a time bin at each <timestamp>
token position; multiplying by ``timestamp_segment_time`` gives milliseconds.
Usage::
python forced_alignment_offline.py \
--model Qwen/Qwen3-ForcedAligner-0.6B
"""
from argparse import Namespace
import numpy as np
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
parser.set_defaults(
model="Qwen/Qwen3-ForcedAligner-0.6B",
runner="pooling",
enforce_eager=True,
hf_overrides={"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]},
)
return parser.parse_args()
def build_prompt(words: list[str]) -> str:
"""Build the forced alignment prompt from a word list.
Format: <|audio_start|><|audio_pad|><|audio_end|>
word1<timestamp><timestamp>word2<timestamp><timestamp>...
"""
body = "<timestamp><timestamp>".join(words) + "<timestamp><timestamp>"
return f"<|audio_start|><|audio_pad|><|audio_end|>{body}"
def main(args: Namespace):
llm = LLM(**vars(args))
config = llm.llm_engine.vllm_config.model_config.hf_config
timestamp_token_id = config.timestamp_token_id
timestamp_segment_time = config.timestamp_segment_time
# Example: align these words against a 5-second audio clip
words = ["Hello", "world"]
prompt = build_prompt(words)
# Use a 5-second silent audio as placeholder (replace with real audio)
sample_rate = 16000
audio = np.zeros(sample_rate * 5, dtype=np.float32)
outputs = llm.encode(
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
pooling_task="token_classify",
)
for output in outputs:
logits = output.outputs.data # [num_tokens, classify_num]
predictions = logits.argmax(dim=-1)
token_ids = output.prompt_token_ids
# Extract timestamps at <timestamp> positions
ts_predictions = [
pred.item() * timestamp_segment_time
for tid, pred in zip(token_ids, predictions)
if tid == timestamp_token_id
]
# Pair up start/end times per word
for i, word in enumerate(words):
start_ms = ts_predictions[i * 2]
end_ms = ts_predictions[i * 2 + 1]
print(f"{word:15s} {start_ms / 1000:.3f}s - {end_ms / 1000:.3f}s")
if __name__ == "__main__":
args = parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
MODEL = "Qwen/Qwen3-ForcedAligner-0.6B"
CLASSIFY_NUM = 5000
TIMESTAMP_TOKEN_ID = 151705
def build_prompt(words: list[str]) -> str:
body = "<timestamp><timestamp>".join(words) + "<timestamp><timestamp>"
return f"<|audio_start|><|audio_pad|><|audio_end|>{body}"
@pytest.mark.parametrize("model", [MODEL])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@torch.inference_mode()
def test_qwen3_forced_aligner(
vllm_runner,
model: str,
dtype: str,
) -> None:
words = ["Hello", "world"]
prompt = build_prompt(words)
# 5-second silent audio at 16kHz
audio = np.zeros(16000 * 5, dtype=np.float32)
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
enforce_eager=True,
max_model_len=512,
hf_overrides={
"architectures": [
"Qwen3ASRForcedAlignerForTokenClassification",
],
},
) as vllm_model:
outputs = vllm_model.llm.encode(
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
pooling_task="token_classify",
)
# Validate output structure
assert len(outputs) == 1
logits = outputs[0].outputs.data
assert logits.dim() == 2
assert logits.shape[1] == CLASSIFY_NUM
# Validate timestamp extraction
token_ids = outputs[0].prompt_token_ids
predictions = logits.argmax(dim=-1)
ts_indices = [i for i, t in enumerate(token_ids) if t == TIMESTAMP_TOKEN_ID]
# 2 words x 2 timestamps each (start + end) = 4
assert len(ts_indices) == 4
ts_preds = [predictions[i].item() for i in ts_indices]
assert all(p >= 0 for p in ts_preds)
# end >= start for each word
assert ts_preds[1] >= ts_preds[0] # Hello
assert ts_preds[3] >= ts_preds[2] # world
......@@ -1094,6 +1094,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
min_transformers_version="4.57",
hf_overrides={"architectures": ["Qwen3ASRRealtimeGeneration"]},
),
"Qwen3ASRForcedAlignerForTokenClassification": _HfExamplesInfo(
"Qwen/Qwen3-ForcedAligner-0.6B",
max_model_len=4096,
min_transformers_version="4.57",
hf_overrides={"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]},
),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo(
"Skywork/Skywork-R1V-38B", trust_remote_code=True
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3-ASR ForcedAligner model (token classification)."""
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.models.interfaces_base import default_pooling_type
from vllm.model_executor.models.qwen3_asr import (
Qwen3ASRDummyInputsBuilder,
Qwen3ASRForConditionalGeneration,
Qwen3ASRMultiModalProcessor,
Qwen3ASRProcessingInfo,
)
from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
@default_pooling_type(tok_pooling_type="ALL")
@MULTIMODAL_REGISTRY.register_processor(
Qwen3ASRMultiModalProcessor,
info=Qwen3ASRProcessingInfo,
dummy_inputs=Qwen3ASRDummyInputsBuilder,
)
class Qwen3ASRForcedAlignerForTokenClassification(
Qwen3ASRForConditionalGeneration,
):
"""Qwen3-ASR Forced Aligner model for per-token timestamp classification.
This model shares the audio tower and language model backbone with
Qwen3-ASR, but replaces the LM head with a classification head that
predicts time bins at ``<timestamp>`` token positions.
Usage::
llm = LLM(
model="Qwen/Qwen3-ForcedAligner-0.6B",
runner="pooling",
hf_overrides={
"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]
},
)
outputs = llm.encode(
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
pooling_task="token_classify",
)
"""
is_pooling_model = True
# Map thinker.lm_head -> classifier (not language_model.lm_head)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"thinker.lm_head.": "classifier.",
"thinker.model.": "language_model.model.",
"thinker.": "",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
thinker_config = config.thinker_config
# Remove the unused generation head created by the base class;
# the forced aligner uses a classifier head instead.
self.language_model.lm_head = None
self.language_model.logits_processor = None
self.classify_num = thinker_config.classify_num
# Classification head replaces lm_head for time-bin prediction.
# Use model dtype (not head_dtype which defaults to float32 for
# pooling models) to match the hidden state dtype.
self.classifier = nn.Linear(
thinker_config.text_config.hidden_size,
self.classify_num,
bias=False,
dtype=vllm_config.model_config.dtype,
)
# Token-level pooler to split per-token logits per request
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = pooler_for_token_classify(pooler_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor:
if intermediate_tensors is not None:
inputs_embeds = None
# Run through language model backbone (transformer layers only)
hidden_states = self.language_model.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
# Apply classification head -> [num_tokens, classify_num]
return self.classifier(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=["talker.", "code2wav."],
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
......@@ -292,6 +292,10 @@ _TOKEN_CLASSIFICATION_MODELS = {
"modernbert",
"ModernBertForTokenClassification",
),
"Qwen3ASRForcedAlignerForTokenClassification": (
"qwen3_asr_forced_aligner",
"Qwen3ASRForcedAlignerForTokenClassification",
),
}
_SEQUENCE_CLASSIFICATION_MODELS = {
......
......@@ -342,12 +342,14 @@ class Qwen3ASRThinkerConfig(PretrainedConfig):
audio_start_token_id=151647,
user_token_id=872,
initializer_range=0.02,
classify_num=None,
**kwargs,
):
super().__init__(**kwargs)
self.user_token_id = user_token_id
self.audio_start_token_id = audio_start_token_id
self.initializer_range = initializer_range
self.classify_num = classify_num
if isinstance(audio_config, dict):
audio_config = Qwen3ASRAudioEncoderConfig(**audio_config)
......@@ -406,6 +408,8 @@ class Qwen3ASRConfig(PretrainedConfig):
self,
thinker_config=None,
support_languages=None,
timestamp_token_id=None,
timestamp_segment_time=None,
**kwargs,
):
if thinker_config is None:
......@@ -416,6 +420,8 @@ class Qwen3ASRConfig(PretrainedConfig):
self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
self.support_languages = support_languages
self.timestamp_token_id = timestamp_token_id
self.timestamp_segment_time = timestamp_segment_time
super().__init__(**kwargs)
def get_text_config(self, decoder=False) -> "PretrainedConfig":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment