"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "add055e151f32f89dab5932d25e5285b2fc823f1"
Unverified commit c8aca0c9, authored by Netanel Haber and committed by GitHub
Browse files

Support parakeet as audio encoder for nemotron-nano-vl (#35100)


Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent b602e4f2
...@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import ( ...@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import (
) )
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
init_vllm_registered_model, init_vllm_registered_model,
...@@ -55,12 +56,14 @@ from vllm.multimodal.evs import ( ...@@ -55,12 +56,14 @@ from vllm.multimodal.evs import (
compute_retention_mask, compute_retention_mask,
) )
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
AudioItem,
MultiModalDataDict, MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
VideoItem, VideoItem,
) )
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
AudioProcessorItems,
ImageEmbeddingItems, ImageEmbeddingItems,
ImageProcessorItems, ImageProcessorItems,
ImageSize, ImageSize,
...@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely ...@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit # Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels # Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
class NanoNemotronVLAudioFeatureInputs(TensorSchema):
    """
    Schema for the pre-extracted audio features handed to the model.

    Dimensions:
        - b: Number of audio clips
        - t: Audio feature length
        - f: Feature size (mel bins)
    """

    # Discriminator tag identifying this input variant.
    type: Literal["audio_features"] = "audio_features"
    # Audio features for every clip, laid out as (b, t, f).
    input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")]
    # Per-clip mask over the time axis, laid out as (b, t).
    # NOTE(review): presumably nonzero entries mark valid (non-padded)
    # frames -- confirm against the feature extractor.
    feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")]
    # One length value per clip, laid out as (b,).
    audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")]
MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes
IMG_START = "<img>" IMG_START = "<img>"
IMG_END = "</img>" IMG_END = "</img>"
IMG_CONTEXT = "<image>" IMG_CONTEXT = "<image>"
AUDIO_START = "<so_start>"
AUDIO_END = "<so_end>"
AUDIO_CONTEXT = "<so_embedding>"
# Profiling # Profiling
# MAX_FRAMES = 16 # MAX_FRAMES = 16
...@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self.video_token = video_token self.video_token = video_token
self.video_pruning_rate = video_pruning_rate self.video_pruning_rate = video_pruning_rate
self.audio_extractor: ParakeetExtractor | None = None
raw_sound_config = getattr(config, "sound_config", None)
if raw_sound_config is not None:
self.audio_extractor = ParakeetExtractor(raw_sound_config)
# Pre-tokenize special tokens for video processing # Pre-tokenize special tokens for video processing
# to avoid repeated tokenization # to avoid repeated tokenization
self._img_start_token_ids = tokenizer.encode( self._img_start_token_ids = tokenizer.encode(
...@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text = [t.replace("<video>", video_repl_text, 1) for t in text] text = [t.replace("<video>", video_repl_text, 1) for t in text]
return text, video_inputs return text, video_inputs
def _preprocess_audio(
    self,
    text: list[str],
    audios: list[npt.NDArray],
):
    """Expand audio placeholders in the prompt and extract audio features.

    Only ``text[0]`` is inspected: every ``AUDIO_CONTEXT`` occurrence in it
    is replaced, in order, by the full replacement string produced by
    ``get_audio_repl`` for the matching clip.

    Args:
        text: Batched prompt strings; only the first entry is used.
        audios: Audio clips, one per ``AUDIO_CONTEXT`` placeholder.

    Returns:
        A ``(text, audio_inputs)`` pair. ``text`` is a single-element list
        holding the expanded prompt and ``audio_inputs`` maps
        ``input_audio_features``, ``feature_attention_mask`` and
        ``audio_feature_lengths`` to tensors. When ``audios`` is empty the
        input ``text`` and an empty dict are returned unchanged.

    Raises:
        ValueError: If the number of placeholders in ``text[0]`` differs
            from ``len(audios)``.
    """
    if len(audios) == 0:
        return text, {}

    assert self.audio_extractor is not None
    extractor = self.audio_extractor

    # The capturing group makes re.split keep each placeholder as its own
    # segment; empty segments (adjacent placeholders, edges) are dropped.
    placeholder_pattern = f"({re.escape(AUDIO_CONTEXT)})"
    segments = [seg for seg in re.split(placeholder_pattern, text[0]) if seg]

    num_placeholders = sum(1 for seg in segments if seg == AUDIO_CONTEXT)
    if num_placeholders != len(audios):
        raise ValueError(
            "Number of audio tokens in text does not match the number "
            f"of audios (tokens={num_placeholders}, audios={len(audios)})."
        )

    # Placeholders consume the clips strictly in prompt order.
    pending_audios = iter(audios)
    expanded_segments = [
        self.get_audio_repl(next(pending_audios)).full
        if seg == AUDIO_CONTEXT
        else seg
        for seg in segments
    ]
    text = ["".join(expanded_segments)]

    extracted = extractor(
        audios,
        sampling_rate=extractor.sampling_rate,
        return_tensors="pt",
    )
    features = extracted.input_features
    mask = extracted.attention_mask
    return text, {
        "input_audio_features": features,
        "feature_attention_mask": mask,
        # Per-clip sum of the mask along the time axis.
        "audio_feature_lengths": mask.sum(dim=1),
    }
def __call__( def __call__(
self, self,
text: str | list[str] | None = None, text: str | list[str] | None = None,
images: Image.Image | list[Image.Image] | None = None, images: Image.Image | list[Image.Image] | None = None,
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None, videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
audios: AudioItem | list[AudioItem] | None = None,
return_tensors: str | TensorType | None = None, return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
) -> BatchFeature: ) -> BatchFeature:
...@@ -964,8 +1034,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -964,8 +1034,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
if max_num_tiles is None: if max_num_tiles is None:
max_num_tiles = self.max_num_tiles max_num_tiles = self.max_num_tiles
text, images, videos = [ text, images, videos, audios = [
self._make_batch_input(x) for x in (text, images, videos) self._make_batch_input(x) for x in (text, images, videos, audios)
] ]
text, image_inputs = self._preprocess_image( text, image_inputs = self._preprocess_image(
...@@ -980,17 +1050,22 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -980,17 +1050,22 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
max_num_tiles=1, max_num_tiles=1,
) )
text, audio_inputs = self._preprocess_audio(
text=text,
audios=audios,
)
text_inputs = self.tokenizer(text, add_special_tokens=False) text_inputs = self.tokenizer(text, add_special_tokens=False)
combined_inputs = {**text_inputs, **video_inputs, **audio_inputs}
if self.dynamic_tiler is None: if self.dynamic_tiler is None:
batch = BatchFeature( batch = BatchFeature(
{**text_inputs, **video_inputs, **image_inputs}, {**combined_inputs, **image_inputs},
tensor_type=return_tensors, tensor_type=return_tensors,
) )
else: else:
batch = BatchFeature( batch = BatchFeature(combined_inputs, tensor_type=return_tensors)
{**text_inputs, **video_inputs}, tensor_type=return_tensors
)
# allow images to be exempt from the BatchFeature validation: # allow images to be exempt from the BatchFeature validation:
# We will .stack() them in _parse_and_validate_image_input # We will .stack() them in _parse_and_validate_image_input
batch.update(image_inputs) batch.update(image_inputs)
...@@ -1006,6 +1081,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -1006,6 +1081,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def get_audio_repl(
    self,
    audio: npt.NDArray,
) -> PromptUpdateDetails[str]:
    """Build the prompt replacement for a single audio clip.

    The replacement is ``AUDIO_START``, then ``AUDIO_CONTEXT`` repeated once
    per audio token, then ``AUDIO_END``. The token count comes from the
    extractor's ``audio_token_count`` applied to ``len(audio)``.

    Args:
        audio: The audio clip; only its length is used here.

    Returns:
        Prompt update details built with ``select_text`` on
        ``AUDIO_CONTEXT``.
    """
    extractor = self.audio_extractor
    assert extractor is not None
    placeholder_count = extractor.audio_token_count(len(audio))
    expanded = AUDIO_START + AUDIO_CONTEXT * placeholder_count + AUDIO_END
    return PromptUpdateDetails.select_text(expanded, AUDIO_CONTEXT)
@classmethod @classmethod
def get_video_repl( def get_video_repl(
cls, cls,
...@@ -1147,15 +1231,28 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): ...@@ -1147,15 +1231,28 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
def supports_video(self): def supports_video(self):
return self.get_hf_processor().supports_video return self.get_hf_processor().supports_video
@property
def audio_extractor(self) -> ParakeetExtractor | None:
    """Audio extractor held by the HF processor, or ``None`` if it has none."""
    hf_processor = self.get_hf_processor()
    return hf_processor.audio_extractor
def get_data_parser(self):
    """Create the multimodal data parser for this model.

    When an audio extractor is available, its sampling rate is passed as
    ``target_sr`` and ``target_channels`` is set to 1; otherwise both are
    ``None``.
    NOTE(review): presumably the parser uses these to resample and downmix
    incoming audio -- confirm against ``MultiModalDataParser``.
    """
    extractor = self.audio_extractor
    if extractor:
        audio_kwargs = {
            "target_sr": extractor.sampling_rate,
            "target_channels": 1,
        }
    else:
        audio_kwargs = {"target_sr": None, "target_channels": None}

    return MultiModalDataParser(
        video_needs_metadata=True,
        expected_hidden_size=self._get_expected_hidden_size(),
        **audio_kwargs,
    )
def get_supported_mm_limits(self):
    """Return the parent's modality limits, extended with video and audio.

    ``"video"`` is added (mapped to ``None``) when video is supported and
    ``"audio"`` is added (mapped to ``None``) when an audio extractor exists.
    """
    optional_limits = {}
    if self.supports_video:
        optional_limits["video"] = None
    if self.audio_extractor is not None:
        optional_limits["audio"] = None
    return {**super().get_supported_mm_limits(), **optional_limits}
def get_video_token(self) -> str | None: def get_video_token(self) -> str | None:
return IMG_CONTEXT return IMG_CONTEXT
...@@ -1304,7 +1401,16 @@ class NanoNemotronVLMultiModalProcessor( ...@@ -1304,7 +1401,16 @@ class NanoNemotronVLMultiModalProcessor(
else: else:
video_fields = {} video_fields = {}
return image_fields | video_fields if self.info.audio_extractor is not None:
audio_fields = dict(
input_audio_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
)
else:
audio_fields = {}
return image_fields | video_fields | audio_fields
def _get_prompt_updates( def _get_prompt_updates(
self, self,
...@@ -1373,6 +1479,20 @@ class NanoNemotronVLMultiModalProcessor( ...@@ -1373,6 +1479,20 @@ class NanoNemotronVLMultiModalProcessor(
), ),
] ]
def get_audio_replacement(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems)
return hf_processor.get_audio_repl(audios.get(item_idx))
if self.info.audio_extractor is not None:
prompt_repl = [
*prompt_repl,
PromptReplacement(
modality="audio",
target=AUDIO_CONTEXT,
replacement=get_audio_replacement,
),
]
return prompt_repl return prompt_repl
...@@ -1422,8 +1542,13 @@ class NanoNemotronVLDummyInputsBuilder( ...@@ -1422,8 +1542,13 @@ class NanoNemotronVLDummyInputsBuilder(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    """Build the dummy prompt text used for profiling.

    Appends one ``<video>`` placeholder per requested video and one
    ``AUDIO_CONTEXT`` placeholder per requested audio to the parent's text.
    """
    video_placeholders = "<video>" * mm_counts.get("video", 0)
    audio_placeholders = AUDIO_CONTEXT * mm_counts.get("audio", 0)
    base_text = super().get_dummy_text(mm_counts)
    return base_text + video_placeholders + audio_placeholders
def _get_dummy_videos( def _get_dummy_videos(
self, self,
...@@ -1482,7 +1607,25 @@ class NanoNemotronVLDummyInputsBuilder( ...@@ -1482,7 +1607,25 @@ class NanoNemotronVLDummyInputsBuilder(
} }
else: else:
dummy_video = {} dummy_video = {}
return {**dummy_image, **dummy_video}
if extractor := self.info.audio_extractor:
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
tokens_per_audio = max(1, seq_len // max(num_audios, 1))
max_audio_num_samples = MAX_AUDIO_LEN_S * extractor.sampling_rate
calculated_max_audio_num_samples = extractor.audio_length(tokens_per_audio)
audio_len = min(max_audio_num_samples, calculated_max_audio_num_samples)
dummy_audio = {
"audio": self._get_dummy_audios(
length=audio_len,
num_audios=num_audios,
overrides=audio_overrides,
)
}
else:
dummy_audio = {}
return {**dummy_image, **dummy_video, **dummy_audio}
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
...@@ -1499,12 +1642,15 @@ class NemotronH_Nano_VL_V2( ...@@ -1499,12 +1642,15 @@ class NemotronH_Nano_VL_V2(
return "<image>" return "<image>"
if modality.startswith("video"): if modality.startswith("video"):
return "<video>" return "<video>"
if modality.startswith("audio"):
return AUDIO_CONTEXT
return None return None
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config model_config = vllm_config.model_config
multimodal_config = vllm_config.model_config.multimodal_config config = model_config.hf_config
multimodal_config = model_config.multimodal_config
image_size = config.force_image_size image_size = config.force_image_size
patch_size = config.patch_size patch_size = config.patch_size
self.patch_size = patch_size self.patch_size = patch_size
...@@ -1523,10 +1669,12 @@ class NemotronH_Nano_VL_V2( ...@@ -1523,10 +1669,12 @@ class NemotronH_Nano_VL_V2(
hf_config=config.text_config, hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "language_model"),
) )
llm_dtype = self.language_model.config.dtype
with self._mark_tower_model(vllm_config, {"image", "video"}): assert isinstance(llm_dtype, torch.dtype)
self.llm_dtype = llm_dtype
with self._mark_tower_model(vllm_config, {"image", "video", "audio"}):
self.vision_model = self.get_vit_model_from_radio_config(config).to( self.vision_model = self.get_vit_model_from_radio_config(config).to(
self.language_model.config.dtype llm_dtype
) )
# Construct the vision projection. # Construct the vision projection.
...@@ -1547,14 +1695,26 @@ class NemotronH_Nano_VL_V2( ...@@ -1547,14 +1695,26 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(), ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False), nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
) )
self.mlp1 = mlp1.to(self.language_model.config.dtype) self.mlp1 = mlp1.to(llm_dtype)
self.sound_encoder: ProjectedParakeet | None = None
if getattr(config, "sound_config", None) is not None:
logger.info_once(
"Found sound config, initializing sound encoder for Nemotron AVLM",
scope="global",
)
self.sound_encoder = ProjectedParakeet(
config.sound_config,
dtype=llm_dtype,
llm_hidden_size=llm_hidden_size,
max_model_len=model_config.max_model_len,
)
self.config = config self.config = config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
# Pre-tokenize special tokens for video processing # Pre-tokenize special tokens for video processing
# to avoid repeated tokenization # to avoid repeated tokenization
tokenizer = cached_tokenizer_from_config(vllm_config.model_config) tokenizer = cached_tokenizer_from_config(model_config)
self._img_start_token_ids = tokenizer.encode( self._img_start_token_ids = tokenizer.encode(
IMG_START, add_special_tokens=False IMG_START, add_special_tokens=False
) )
...@@ -1566,7 +1726,10 @@ class NemotronH_Nano_VL_V2( ...@@ -1566,7 +1726,10 @@ class NemotronH_Nano_VL_V2(
config config
) )
if self.dynamic_resolution: if self.dynamic_resolution:
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor") logger.info_once(
"Dynamic resolution is enabled for NanoNemotronVLProcessor",
scope="global",
)
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size() n, w, h, c = x.size()
...@@ -1780,6 +1943,51 @@ class NemotronH_Nano_VL_V2( ...@@ -1780,6 +1943,51 @@ class NemotronH_Nano_VL_V2(
return final_video_embeddings return final_video_embeddings
def _process_audio_input(
    self, audio_input: NanoNemotronVLAudioFeatureInputs
) -> tuple[torch.Tensor, ...]:
    """Encode audio features and return one embedding tensor per clip.

    Args:
        audio_input: Audio features and their attention mask. Both may be
            either stacked tensors or lists of per-clip tensors.

    Returns:
        A tuple with one tensor per clip, each truncated along its first
        axis to the clip's valid output length.
    """
    assert self.sound_encoder is not None
    input_audio_features = audio_input.input_audio_features
    feature_attention_mask = audio_input.feature_attention_mask
    target_device = next(self.sound_encoder.parameters()).device

    # When cross-request batching combines audio clips with different
    # time dimensions, _reduce_data returns a list instead of a stacked
    # tensor. Pad to the max time dim and stack; the attention mask
    # already marks valid positions so zero-padding is safe.
    if isinstance(input_audio_features, list):
        max_t = max(feat.shape[-2] for feat in input_audio_features)
        input_audio_features = torch.stack(
            [
                # Pad only the end of the time axis (second-to-last dim).
                torch.nn.functional.pad(feat, (0, 0, 0, max_t - feat.shape[-2]))
                for feat in input_audio_features
            ]
        )
        feature_attention_mask = torch.stack(
            [
                torch.nn.functional.pad(mask, (0, max_t - mask.shape[-1]))
                for mask in feature_attention_mask
            ]
        )

    input_audio_features = input_audio_features.to(
        dtype=self.llm_dtype, device=target_device
    )
    feature_attention_mask = feature_attention_mask.to(device=target_device)
    sound_embeds = self.sound_encoder(input_audio_features, feature_attention_mask)

    valid_input_lens = feature_attention_mask.sum(dim=1)
    valid_output_lens = self.sound_encoder.encoder._get_subsampling_output_length(
        valid_input_lens
    )
    # Convert all lengths to Python ints in one transfer instead of calling
    # .item() per clip, which would force a host sync on every iteration.
    valid_lens = valid_output_lens.tolist()
    return tuple(
        clip_embeds[:valid_len]
        for clip_embeds, valid_len in zip(sound_embeds, valid_lens, strict=True)
    )
def _create_final_video_embeddings( def _create_final_video_embeddings(
self, self,
video_embeddings: torch.Tensor, video_embeddings: torch.Tensor,
...@@ -1887,6 +2095,18 @@ class NemotronH_Nano_VL_V2( ...@@ -1887,6 +2095,18 @@ class NemotronH_Nano_VL_V2(
modalities["images"] = self._parse_and_validate_image_input(**kwargs) modalities["images"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_flat_video",) and "videos" not in modalities: if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(**kwargs) modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
if (
input_key
in (
"input_audio_features",
"feature_attention_mask",
"audio_feature_lengths",
)
and "audios" not in modalities
):
modalities["audios"] = NanoNemotronVLAudioFeatureInputs(
**kwargs, validate=False
)
return modalities return modalities
...@@ -1917,6 +2137,10 @@ class NemotronH_Nano_VL_V2( ...@@ -1917,6 +2137,10 @@ class NemotronH_Nano_VL_V2(
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input) video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += tuple(video_embeddings) multimodal_embeddings += tuple(video_embeddings)
if modality == "audios":
audio_input = modalities["audios"]
audio_embeddings = self._process_audio_input(audio_input)
multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings return multimodal_embeddings
...@@ -1947,8 +2171,8 @@ class NemotronH_Nano_VL_V2( ...@@ -1947,8 +2171,8 @@ class NemotronH_Nano_VL_V2(
""" """
return MultiModelKeys.from_string_field( return MultiModelKeys.from_string_field(
language_model="language_model", language_model="language_model",
connector="mlp1", connector=["mlp1", "sound_encoder.projection"],
tower_model="vision_model", tower_model=["vision_model", "sound_encoder.encoder"],
) )
def compute_logits( def compute_logits(
...@@ -1969,9 +2193,13 @@ class NemotronH_Nano_VL_V2( ...@@ -1969,9 +2193,13 @@ class NemotronH_Nano_VL_V2(
def is_vision_weights(name: str) -> bool: def is_vision_weights(name: str) -> bool:
return name.startswith("vision_model.radio_model.") return name.startswith("vision_model.radio_model.")
def is_sound_weights(name: str) -> bool:
return name.startswith("sound")
# Separate weights by component # Separate weights by component
llm_weights = [] llm_weights = []
vision_weights = [] vision_weights = []
sound_weights = []
for name, w in weights: for name, w in weights:
if is_llm(name): if is_llm(name):
...@@ -1987,9 +2215,15 @@ class NemotronH_Nano_VL_V2( ...@@ -1987,9 +2215,15 @@ class NemotronH_Nano_VL_V2(
# Convert: vision_model.radio_model.* → radio_model.* # Convert: vision_model.radio_model.* → radio_model.*
hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix
vision_weights.append((hf_key, w)) vision_weights.append((hf_key, w))
elif is_sound_weights(name):
assert self.sound_encoder is not None
sound_weights.append((name, w))
self.language_model.load_weights(llm_weights) self.language_model.load_weights(llm_weights)
self.vision_model.load_weights(vision_weights) self.vision_model.load_weights(vision_weights)
if self.sound_encoder is not None:
assert len(sound_weights) > 0
self.sound_encoder.load_weights(sound_weights)
def print_architecture(self, detailed: bool = True, save_to_file: str = None): def print_architecture(self, detailed: bool = True, save_to_file: str = None):
""" """
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Modules below used for the audio encoder component in: models/nano_nemotron_vl.py
"""
from collections.abc import Iterable
from dataclasses import asdict
import numpy as np
import torch
import torch.nn as nn
from transformers import ParakeetEncoder as HFParakeetEncoder
from transformers import ParakeetFeatureExtractor, PretrainedConfig
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig
class ParakeetProjection(nn.Module):
    """Project audio encoder hidden states into the LLM hidden size.

    Applies LayerNorm, a linear layer, a squared-ReLU activation and a
    second linear layer, in that order.
    """

    def __init__(self, config: ParakeetConfig) -> None:
        super().__init__()
        encoder_dim = config.hidden_size
        inner_dim = config.projection_hidden_size
        output_dim = config.llm_hidden_size
        use_bias = config.projection_bias

        self.norm = nn.LayerNorm(encoder_dim, eps=config.projection_eps)
        self.linear1 = nn.Linear(encoder_dim, inner_dim, bias=use_bias)
        self.activation = ReLUSquaredActivation()
        self.linear2 = nn.Linear(inner_dim, output_dim, bias=use_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the projected hidden states."""
        normed = self.norm(hidden_states)
        activated = self.activation(self.linear1(normed))
        return self.linear2(activated)
class ProjectedParakeet(nn.Module):
    """Parakeet audio encoder followed by a projection to the LLM hidden size.

    Both sub-modules are cast to the ``dtype`` given at construction time.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        *,
        dtype: torch.dtype,
        llm_hidden_size: int,
        max_model_len: int,
    ) -> None:
        """Build the encoder and projection.

        Args:
            config: HF sound config used to derive the ``ParakeetConfig``.
            dtype: Parameter dtype for the encoder and the projection.
            llm_hidden_size: Output size of the projection.
            max_model_len: Forwarded to ``ParakeetConfig.from_hf_config``.
        """
        super().__init__()
        self.config = ParakeetConfig.from_hf_config(
            config, llm_hidden_size=llm_hidden_size, max_model_len=max_model_len
        )
        self.encoder = HFParakeetEncoder(self.config).to(dtype)
        self.projection = ParakeetProjection(self.config).to(dtype)

    def forward(
        self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Encode ``input_features`` and project the last hidden state.

        Args:
            input_features: Audio features passed to the encoder.
            attention_mask: Optional mask passed to the encoder.

        Returns:
            The projected hidden states.
        """
        hidden_states = self.encoder(
            input_features=input_features, attention_mask=attention_mask
        ).last_hidden_state
        # Cast to the projection's own parameter dtype rather than a
        # hard-coded bfloat16, so fp16/fp32 models do not hit a dtype
        # mismatch inside the projection layers.
        projection_dtype = next(self.projection.parameters()).dtype
        hidden_states = hidden_states.to(dtype=projection_dtype)
        return self.projection(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the encoder and projection.

        Names starting with ``sound_encoder.`` have that prefix stripped;
        names starting with ``sound_projection.`` are mapped under
        ``projection.``. Any other name is ignored.

        Args:
            weights: ``(name, tensor)`` pairs, or a dict mapping names to
                tensors.

        Returns:
            The set of parameter/buffer names (relative to this module)
            that were loaded.

        Raises:
            ValueError: If a mapped name matches neither a parameter nor a
                buffer of this module.
        """
        loaded_params: set[str] = set()
        params_dict = dict(self.named_parameters())
        buffers_dict = dict(self.named_buffers())

        # Iterate lazily: there is no need to hold every tensor of the
        # incoming iterable alive at the same time.
        named_weights = weights.items() if isinstance(weights, dict) else weights

        for name, weight in named_weights:
            if name.startswith("sound_encoder.encoder.feature_extractor."):
                # Feature extractor buffers are handled outside the encoder.
                continue
            if name.startswith("sound_encoder."):
                target_name = name[len("sound_encoder.") :]
            elif name.startswith("sound_projection."):
                target_name = f"projection.{name[len('sound_projection.') :]}"
            else:
                continue

            # Parameters take precedence; fall back to buffers.
            target = params_dict.get(target_name)
            if target is None:
                target = buffers_dict.get(target_name)
            if target is None:
                raise ValueError(f"Unknown weight: {name}")

            weight_loader = getattr(target, "weight_loader", default_weight_loader)
            with torch.no_grad():
                weight_loader(target, weight)
            loaded_params.add(target_name)
        return loaded_params
class ParakeetExtractor(ParakeetFeatureExtractor):
    """Parakeet feature extractor that pads clips before extraction.

    Each clip is zero-padded at the end up to the length computed by
    ``_normalize_audio_length`` before the parent extractor is called.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        self.config = ExtractorConfig.from_hf_config(config)
        super().__init__(**asdict(self.config))
        sampling_rate = self.sampling_rate
        # Clip and minimum-tail durations converted from seconds to samples.
        self._clip_target_samples = int(
            round(self.config.clip_duration_s * sampling_rate)
        )
        self._tail_min_samples = int(
            round(self.config.clip_min_duration_s * sampling_rate)
        )

    def _normalize_audio_length(self, audio_len: int) -> int:
        """Return the padded sample count for a clip of ``audio_len`` samples."""
        # Match mcore's compute_params() logic for clip/minduration handling.
        min_tail = self._tail_min_samples
        normalized_len = max(audio_len, min_tail)
        remainder = normalized_len % self._clip_target_samples
        # A non-empty tail shorter than the minimum is grown to the minimum.
        if 0 < remainder < min_tail:
            normalized_len += min_tail - remainder
        assert isinstance(normalized_len, int)
        return normalized_len

    def audio_token_count(self, audio_len: int) -> int:
        """Return the number of audio tokens (at least 1) for ``audio_len`` samples."""
        padded_len = self._normalize_audio_length(audio_len)
        frame_count = padded_len // self.hop_length
        # The encoder helper is called unbound with this extractor as
        # ``self``. NOTE(review): presumably it only reads ``self.config``
        # subsampling fields -- confirm against the transformers source.
        token_count = HFParakeetEncoder._get_subsampling_output_length(
            self, torch.tensor([frame_count], dtype=torch.float)
        )
        return max(1, token_count.item())

    def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs):
        """Pad every 1-D clip to its normalized length, then extract features."""
        padded_clips = []
        for waveform in raw_speech:
            assert waveform.ndim == 1
            num_samples = int(waveform.shape[0])
            pad_amount = self._normalize_audio_length(num_samples) - num_samples
            padded_clips.append(np.pad(waveform, (0, pad_amount)))
        return super().__call__(padded_clips, *args, **kwargs)

    def audio_length(self, audio_tokens: int) -> int:
        """Return the sample count corresponding to ``audio_tokens`` tokens."""
        samples_per_token = self.config.subsampling_factor * self.hop_length
        return int(audio_tokens * samples_per_token)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from transformers import ParakeetEncoderConfig, PretrainedConfig
class ParakeetConfig(ParakeetEncoderConfig):
    """Parakeet encoder config extended with projection settings."""

    # Output size of the projection (the LLM hidden size).
    llm_hidden_size: int
    # Hidden size, bias flag and LayerNorm epsilon of the projection.
    projection_hidden_size: int
    projection_bias: bool
    projection_eps: float = 1e-5
    sampling_rate: int

    @staticmethod
    def from_hf_config(
        config: PretrainedConfig, *, llm_hidden_size: int, max_model_len: int
    ) -> "ParakeetConfig":
        """Build a ``ParakeetConfig`` from a generic HF sound config.

        Every key of ``config.to_dict()`` is forwarded. ``scale_input``,
        ``attention_bias``, ``llm_hidden_size`` and
        ``max_position_embeddings`` are then forced to the values set here.

        Args:
            config: Source HF config.
            llm_hidden_size: Output size of the projection.
            max_model_len: Maximum model length; the position-embedding
                limit is set to this value plus one.

        Returns:
            The populated ``ParakeetConfig``.
        """
        assert isinstance(config, PretrainedConfig)
        kwargs = config.to_dict()
        # Merge instead of passing ``**config.to_dict()`` next to explicit
        # keywords: a duplicate keyword would raise TypeError if the
        # serialized config already carried one of these keys.
        kwargs.update(
            scale_input=False,
            attention_bias=False,
            llm_hidden_size=llm_hidden_size,
            # + 1 because it seems like max_model_len+1 can be passed
            max_position_embeddings=max_model_len + 1,
        )
        return ParakeetConfig(**kwargs)
@dataclass(kw_only=True, frozen=True)
class ExtractorConfig:
    """Immutable settings used to construct the Parakeet feature extractor."""

    feature_size: int
    sampling_rate: int
    subsampling_factor: int
    subsampling_conv_kernel_size: int
    subsampling_conv_stride: int
    # Clip duration and minimum clip duration, in seconds.
    clip_duration_s: int = 30
    clip_min_duration_s: float = 0.1

    @staticmethod
    def from_hf_config(config: PretrainedConfig) -> "ExtractorConfig":
        """Build an ``ExtractorConfig`` from an HF sound config.

        ``feature_size`` is taken from ``config.num_mel_bins``; the other
        required fields are copied from the same-named config attributes.
        The clip durations keep their defaults.
        """
        assert isinstance(config, PretrainedConfig)
        same_named_fields = (
            "sampling_rate",
            "subsampling_factor",
            "subsampling_conv_kernel_size",
            "subsampling_conv_stride",
        )
        copied = {name: getattr(config, name) for name in same_named_fields}
        return ExtractorConfig(feature_size=config.num_mel_bins, **copied)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment