Unverified commit ccee371e, authored by Hyogeun Oh (오효근) and committed by GitHub
Browse files

[Docs] Fix warnings in `mkdocs build` (continued) (#24092)


Signed-off-by: Zerohertz <ohg3417@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent c0bd6a68
...@@ -755,7 +755,7 @@ class FusedMoE(CustomOp): ...@@ -755,7 +755,7 @@ class FusedMoE(CustomOp):
intermediate_size: Intermediate size of the experts intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure. quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer. enable_eplb: Whether to enable expert parallelism load balancer.
""" """
......
...@@ -420,9 +420,8 @@ def shuffle_weights( ...@@ -420,9 +420,8 @@ def shuffle_weights(
Args: Args:
*tensors: Variable number of torch.Tensor objects. *tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the layout: A pair of integers specifying the block sizes used to divide
block sizes used to divide the tensors during shuffling. the tensors during shuffling. Default is (16, 16).
Default is (16, 16).
Returns: Returns:
A Tuple of shuffled tensors. A Tuple of shuffled tensors.
......
...@@ -10,7 +10,7 @@ like uniform random routing. ...@@ -10,7 +10,7 @@ like uniform random routing.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Any, Optional
import torch import torch
...@@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy): ...@@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy):
distributions for testing different routing patterns. distributions for testing different routing patterns.
""" """
def __init__(self, distribution: str = "uniform", **distribution_params): def __init__(self,
distribution: str = "uniform",
**distribution_params: Any):
""" """
Initialize distribution-based routing. Initialize distribution-based routing.
...@@ -244,7 +246,7 @@ class RoutingSimulator: ...@@ -244,7 +246,7 @@ class RoutingSimulator:
cls._routing_strategies[name] = strategy cls._routing_strategies[name] = strategy
@classmethod @classmethod
def get_available_strategies(cls): def get_available_strategies(cls) -> list[str]:
""" """
Get list of available routing strategy names. Get list of available routing strategy names.
......
...@@ -202,7 +202,7 @@ class BitBLASLinearMethod(LinearMethodBase): ...@@ -202,7 +202,7 @@ class BitBLASLinearMethod(LinearMethodBase):
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ) -> None:
"""Creates quantized weights for use in linear operations. """Creates quantized weights for use in linear operations.
The function initializes and returns a dictionary containing quantized The function initializes and returns a dictionary containing quantized
...@@ -211,7 +211,7 @@ class BitBLASLinearMethod(LinearMethodBase): ...@@ -211,7 +211,7 @@ class BitBLASLinearMethod(LinearMethodBase):
Args: Args:
input_size_per_partition: The size of the input partition. input_size_per_partition: The size of the input partition.
output_size_per_partition: The size of the output partition. output_partition_sizes: List of output partition sizes.
input_size: The total size of the input (unused). input_size: The total size of the input (unused).
output_size: The total size of the output (unused). output_size: The total size of the output (unused).
params_dtype: params_dtype:
...@@ -222,9 +222,9 @@ class BitBLASLinearMethod(LinearMethodBase): ...@@ -222,9 +222,9 @@ class BitBLASLinearMethod(LinearMethodBase):
scales ('scales'), and zeros ('zeros'). scales ('scales'), and zeros ('zeros').
Raises: Raises:
ValueError: If `params_dtype` is not `torch.float16` or if the ValueError: If `params_dtype` is not `torch.float16` or if the input
input size per partition is not divisible by the group size in size per partition is not divisible by the group size
`quant_config`. in `quant_config`.
""" """
del input_size, output_size # Unused arguments. del input_size, output_size # Unused arguments.
weight_loader = extra_weight_attrs["weight_loader"] weight_loader = extra_weight_attrs["weight_loader"]
......
...@@ -265,9 +265,9 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): ...@@ -265,9 +265,9 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
scales ('scales'), and zeros ('zeros'). scales ('scales'), and zeros ('zeros').
Raises: Raises:
ValueError: If `params_dtype` is not `torch.float16` or ValueError: If `params_dtype` is not `torch.float16` or if the input
if the input size per partition is not divisible by the size per partition is not divisible by the group size
group size in `quant_config`. in `quant_config`.
""" """
if params_dtype != torch.float16: if params_dtype != torch.float16:
raise ValueError("Parameter data type must be torch.float16, " raise ValueError("Parameter data type must be torch.float16, "
......
...@@ -49,8 +49,8 @@ def choose_mp_linear_kernel( ...@@ -49,8 +49,8 @@ def choose_mp_linear_kernel(
config (MPLinearLayerConfig): Description of the linear layer to be config (MPLinearLayerConfig): Description of the linear layer to be
implemented. implemented.
compute_capability (Optional[int], optional): The compute capability of compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the compute the target device, if None uses `current_platform` to get
capability. Defaults to None. the compute capability. Defaults to None.
Raises: Raises:
ValueError: If no kernel can implement the given config. ValueError: If no kernel can implement the given config.
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import abc import abc
import math import math
from typing import Literal, Optional from typing import Any, Literal, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -131,31 +131,31 @@ class ConformerEncoderLayer(nn.Module): ...@@ -131,31 +131,31 @@ class ConformerEncoderLayer(nn.Module):
def __init__( def __init__(
self, self,
d_model=512, d_model: int = 512,
ext_pw_out_channel=0, ext_pw_out_channel: int = 0,
depthwise_seperable_out_channel=256, depthwise_seperable_out_channel: int = 256,
depthwise_multiplier=1, depthwise_multiplier: int = 1,
n_head=4, n_head: int = 4,
d_ffn=2048, d_ffn: int = 2048,
ext_pw_kernel_size=1, ext_pw_kernel_size: int = 1,
kernel_size=3, kernel_size: int = 3,
dropout_rate=0.1, dropout_rate: float = 0.1,
causal=False, causal: bool = False,
batch_norm=False, batch_norm: bool = False,
activation="relu", activation: str = "relu",
chunk_se=0, chunk_se: int = 0,
chunk_size=18, chunk_size: int = 18,
conv_activation="relu", conv_activation: str = "relu",
conv_glu_type="sigmoid", conv_glu_type: str = "sigmoid",
bias_in_glu=True, bias_in_glu: bool = True,
linear_glu_in_convm=False, linear_glu_in_convm: bool = False,
attention_inner_dim=-1, attention_inner_dim: int = -1,
attention_glu_type="swish", attention_glu_type: str = "swish",
activation_checkpointing="", activation_checkpointing: str = "",
export=False, export: bool = False,
use_pt_scaled_dot_product_attention=False, use_pt_scaled_dot_product_attention: bool = False,
attn_group_sizes: int = 1, attn_group_sizes: int = 1,
): ) -> None:
super().__init__() super().__init__()
self.feed_forward_in = FeedForward( self.feed_forward_in = FeedForward(
...@@ -209,24 +209,21 @@ class ConformerEncoderLayer(nn.Module): ...@@ -209,24 +209,21 @@ class ConformerEncoderLayer(nn.Module):
def forward( def forward(
self, self,
x, x: torch.Tensor,
pos_k, pos_k: torch.Tensor,
pos_v, pos_v: torch.Tensor,
mask, mask: torch.Tensor,
relative_attention_bias: Optional[Tensor] = None, relative_attention_bias: Optional[Tensor] = None,
): ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""ConformerEncoder forward. """ConformerEncoder forward.
Args: Args:
x: torch.Tensor x: input feature of shape (batch, max_time_in, size)
input feature of shape (batch, max_time_in, size) pos_k: positional key embedding.
pos_k: torch.Tensor pos_v: positional value embedding.
positional key embedding. mask: mask for x (batch, max_time_in)
mask: torch.Tensor relative_attention_bias: bias added to attention logits w.r.t.
mask for x (batch, max_time_in) relative positions (1, n_head, time1, time2)
relative_attention_bias: Optional[torch.Tensor]
bias added to attention logits w.r.t. relative positions
(1, n_head, time1, time2)
""" """
x = x + 0.5 * self.feed_forward_in(x) x = x + 0.5 * self.feed_forward_in(x)
norm_x = self.layer_norm_att(x) norm_x = self.layer_norm_att(x)
...@@ -323,25 +320,25 @@ class TransformerEncoderBase(abc.ABC, nn.Module): ...@@ -323,25 +320,25 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
def __init__( def __init__(
self, self,
input_size, input_size: int,
chunk_size, chunk_size: Union[int, list[int]],
left_chunk, left_chunk: Union[int, list[int]],
attention_dim=256, attention_dim: int = 256,
attention_heads=4, attention_heads: int = 4,
input_layer="nemo_conv", input_layer: str = "nemo_conv",
cnn_out=-1, cnn_out: int = -1,
cnn_layer_norm=False, cnn_layer_norm: bool = False,
time_reduction=4, time_reduction: int = 4,
dropout_rate=0.0, dropout_rate: float = 0.0,
padding_idx=-1, padding_idx: int = -1,
relative_attention_bias_args=None, relative_attention_bias_args: Optional[dict[str, Any]] = None,
positional_dropout_rate=0.0, positional_dropout_rate: float = 0.0,
nemo_conv_settings=None, nemo_conv_settings: Optional[dict[str, Any]] = None,
conv2d_extra_padding: Literal["feat", "feat_time", "none", conv2d_extra_padding: Literal["feat", "feat_time", "none",
True] = "none", True] = "none",
attention_group_size=1, attention_group_size: int = 1,
encoder_embedding_config=None, encoder_embedding_config: Optional[dict[str, Any]] = None,
): ) -> None:
super().__init__() super().__init__()
self.input_size = input_size self.input_size = input_size
self.input_layer = input_layer self.input_layer = input_layer
...@@ -399,7 +396,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module): ...@@ -399,7 +396,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
self.encoder_embedding = MeanVarianceNormLayer( self.encoder_embedding = MeanVarianceNormLayer(
self.encoder_embedding_config["input_size"]) self.encoder_embedding_config["input_size"])
def compute_lens_change(self, feature_lens): def compute_lens_change(
self,
feature_lens: Union[int,
torch.Tensor]) -> Union[int, torch.Tensor]:
"""feature_lens: int """feature_lens: int
return updated feature lens. return updated feature lens.
...@@ -433,10 +433,14 @@ class TransformerEncoderBase(abc.ABC, nn.Module): ...@@ -433,10 +433,14 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
return ceil_func(feature_lens / self.time_reduction) return ceil_func(feature_lens / self.time_reduction)
@abc.abstractmethod @abc.abstractmethod
def forward(self): def forward(self) -> Any:
"""Abstract forward method implementation.""" """Abstract forward method implementation."""
def _chunk_size_selection(self, chunk_size=None, left_chunk=None): def _chunk_size_selection(
self,
chunk_size: Optional[Union[int, list[int]]] = None,
left_chunk: Optional[Union[int,
list[int]]] = None) -> tuple[int, int]:
"""If chunk size is a list, we will randomly select a chunk size.""" """If chunk size is a list, we will randomly select a chunk size."""
if chunk_size is None: if chunk_size is None:
...@@ -463,7 +467,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): ...@@ -463,7 +467,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
return chunk_size_train_eff, left_chunk_train_eff return chunk_size_train_eff, left_chunk_train_eff
def _get_embed_class(self, embed): def _get_embed_class(self, embed: nn.Module) -> nn.Module:
# pylint: disable=protected-access # pylint: disable=protected-access
is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
...@@ -474,13 +478,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module): ...@@ -474,13 +478,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
embed_class = embed.module embed_class = embed.module
return embed_class return embed_class
def _forward_embeddings_core(self, input_tensor, masks): def _forward_embeddings_core(
self, input_tensor: torch.Tensor,
masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
embed_class = self._get_embed_class(self.embed) embed_class = self._get_embed_class(self.embed)
assert isinstance(embed_class, NemoConvSubsampling) assert isinstance(embed_class, NemoConvSubsampling)
input_tensor, masks = self.embed(input_tensor, masks) input_tensor, masks = self.embed(input_tensor, masks)
return input_tensor, masks return input_tensor, masks
def _position_embedding(self, input_tensor): def _position_embedding(
self, input_tensor: torch.Tensor
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
pos_k = None pos_k = None
pos_v = None pos_v = None
if self.relative_attention_bias_layer is None: if self.relative_attention_bias_layer is None:
...@@ -488,7 +496,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module): ...@@ -488,7 +496,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
input_tensor) # default to add abs sinusoid embedding input_tensor) # default to add abs sinusoid embedding
return pos_k, pos_v return pos_k, pos_v
def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): def _streaming_mask(self, seq_len: int, batch_size: int,
chunk_size: Union[int, list[int]],
left_chunk: Union[int, list[int]]) -> torch.Tensor:
chunk_size_train_eff, left_chunk_train_eff = \ chunk_size_train_eff, left_chunk_train_eff = \
self._chunk_size_selection(chunk_size, left_chunk) self._chunk_size_selection(chunk_size, left_chunk)
...@@ -502,11 +512,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module): ...@@ -502,11 +512,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
[batch_size, -1, -1])) [batch_size, -1, -1]))
return enc_streaming_mask return enc_streaming_mask
def forward_embeddings(self, def forward_embeddings(
xs_pad, self,
masks, xs_pad: torch.Tensor,
chunk_size_nc=None, masks: torch.Tensor,
left_chunk_nc=None): chunk_size_nc: Optional[Union[int, list[int]]] = None,
left_chunk_nc: Optional[Union[int, list[int]]] = None
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor, torch.Tensor],
tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor]]:
"""Forwarding the inputs through the top embedding layers """Forwarding the inputs through the top embedding layers
Args: Args:
...@@ -569,7 +585,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): ...@@ -569,7 +585,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
return input_tensor, pos_k, pos_v, hs_mask, masks return input_tensor, pos_k, pos_v, hs_mask, masks
return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc
def get_offset(self): def get_offset(self) -> int:
"""Returns offset used when retaining inputs for decoding. """Returns offset used when retaining inputs for decoding.
This is essentially, how many additional frames have to be added to This is essentially, how many additional frames have to be added to
...@@ -605,8 +621,6 @@ class ConformerEncoder(TransformerEncoderBase): ...@@ -605,8 +621,6 @@ class ConformerEncoder(TransformerEncoderBase):
Some examples for the 2 cases: Some examples for the 2 cases:
left_chunk = 6 left_chunk = 6
left_chunk = [12, 9, 6, 3] left_chunk = [12, 9, 6, 3]
left_chunk: int
number of chunks used for masking in streaming mode.
num_lang: int num_lang: int
This parameter is used to store the number of languages in the This parameter is used to store the number of languages in the
lang_dict, only used for multiseed/multilingual models. lang_dict, only used for multiseed/multilingual models.
...@@ -751,46 +765,46 @@ class ConformerEncoder(TransformerEncoderBase): ...@@ -751,46 +765,46 @@ class ConformerEncoder(TransformerEncoderBase):
def __init__( # pylint: disable-all def __init__( # pylint: disable-all
self, self,
input_size, input_size: int,
chunk_size, chunk_size: Union[int, list[int]],
left_chunk, left_chunk: Union[int, list[int]],
num_lang=None, num_lang: Optional[int] = None,
attention_dim=256, attention_dim: int = 256,
attention_heads=4, attention_heads: int = 4,
linear_units=2048, linear_units: int = 2048,
num_blocks=6, num_blocks: int = 6,
dropout_rate=0.1, dropout_rate: float = 0.1,
input_layer="nemo_conv", input_layer: str = "nemo_conv",
causal=True, causal: bool = True,
batch_norm=False, batch_norm: bool = False,
cnn_out=-1, cnn_out: int = -1,
cnn_layer_norm=False, cnn_layer_norm: bool = False,
ext_pw_out_channel=0, ext_pw_out_channel: int = 0,
ext_pw_kernel_size=1, ext_pw_kernel_size: int = 1,
depthwise_seperable_out_channel=256, depthwise_seperable_out_channel: int = 256,
depthwise_multiplier=1, depthwise_multiplier: int = 1,
chunk_se=0, chunk_se: int = 0,
kernel_size=3, kernel_size: int = 3,
activation="relu", activation: str = "relu",
conv_activation="relu", conv_activation: str = "relu",
conv_glu_type="sigmoid", conv_glu_type: str = "sigmoid",
bias_in_glu=True, bias_in_glu: bool = True,
linear_glu_in_convm=False, linear_glu_in_convm: bool = False,
attention_glu_type="swish", attention_glu_type: str = "swish",
export=False, export: bool = False,
extra_layer_output_idx=-1, extra_layer_output_idx: int = -1,
extra_multi_layer_output_idxs=[], # noqa extra_multi_layer_output_idxs: list[int] = [], # noqa
activation_checkpointing="", activation_checkpointing: str = "",
relative_attention_bias_args=None, relative_attention_bias_args: Optional[dict[str, Any]] = None,
time_reduction=4, time_reduction: int = 4,
use_pt_scaled_dot_product_attention=False, use_pt_scaled_dot_product_attention: bool = False,
nemo_conv_settings=None, nemo_conv_settings: Optional[dict[str, Any]] = None,
conv2d_extra_padding: Literal["feat", "feat_time", "none", conv2d_extra_padding: Literal["feat", "feat_time", "none",
True] = "none", True] = "none",
replication_pad_for_subsample_embedding=False, replication_pad_for_subsample_embedding: bool = False,
attention_group_size=1, attention_group_size: int = 1,
encoder_embedding_config=None, encoder_embedding_config: Optional[dict[str, Any]] = None,
): ) -> None:
super().__init__( super().__init__(
input_size, input_size,
chunk_size, chunk_size,
...@@ -852,11 +866,13 @@ class ConformerEncoder(TransformerEncoderBase): ...@@ -852,11 +866,13 @@ class ConformerEncoder(TransformerEncoderBase):
# the device and the needed dtype: # the device and the needed dtype:
self.register_buffer("dev_type", torch.zeros(()), persistent=False) self.register_buffer("dev_type", torch.zeros(()), persistent=False)
def init_relative_attention_bias(self, input_tensor): def init_relative_attention_bias(
self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]:
if self.relative_attention_bias_layer: if self.relative_attention_bias_layer:
return self.relative_attention_bias_layer(input_tensor) return self.relative_attention_bias_layer(input_tensor)
def calculate_hs_mask(self, xs_pad, device, mask): def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device,
mask: Optional[torch.Tensor]) -> torch.Tensor:
max_audio_length = xs_pad.shape[1] max_audio_length = xs_pad.shape[1]
batch_size = xs_pad.shape[0] batch_size = xs_pad.shape[0]
enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
...@@ -877,7 +893,8 @@ class ConformerEncoder(TransformerEncoderBase): ...@@ -877,7 +893,8 @@ class ConformerEncoder(TransformerEncoderBase):
return pad_mask return pad_mask
@torch.jit.ignore @torch.jit.ignore
def forward(self, xs_pad, masks): def forward(self, xs_pad: torch.Tensor,
masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Conformer Forward function """Conformer Forward function
Args: Args:
...@@ -997,7 +1014,12 @@ class WindowQformer(nn.Module): ...@@ -997,7 +1014,12 @@ class WindowQformer(nn.Module):
if normalize_before else None) if normalize_before else None)
self.window_size = window_size self.window_size = window_size
def forward(self, audio_embed, mask, embed_len=None): def forward(
self,
audio_embed: torch.Tensor,
mask: Optional[torch.Tensor],
embed_len: Optional[int] = None
) -> tuple[torch.Tensor, Optional[int]]:
"""forward decoder""" """forward decoder"""
# audio_embed: N x T x D => N x D x T # audio_embed: N x T x D => N x D x T
...@@ -1042,7 +1064,7 @@ class WindowQformer(nn.Module): ...@@ -1042,7 +1064,7 @@ class WindowQformer(nn.Module):
class AudioEmbedding(nn.Module): class AudioEmbedding(nn.Module):
"""Image embedding.""" """Image embedding."""
def __init__(self, config: PretrainedConfig, **kwargs) -> None: def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
# n_embed or hidden_size for text LM # n_embed or hidden_size for text LM
...@@ -1148,19 +1170,18 @@ class AudioEmbedding(nn.Module): ...@@ -1148,19 +1170,18 @@ class AudioEmbedding(nn.Module):
self.input_embeds = None self.input_embeds = None
self.audio_embed_sizes = None self.audio_embed_sizes = None
def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: def set_audio_embeds(self, input_embeds: torch.Tensor) -> None:
self.input_embeds = input_embeds self.input_embeds = input_embeds
def set_audio_embed_sizes(self, def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None:
audio_embed_sizes: torch.LongTensor) -> None:
self.audio_embed_sizes = audio_embed_sizes self.audio_embed_sizes = audio_embed_sizes
def get_audio_features( def get_audio_features(
self, self,
input_embeds: torch.FloatTensor, input_embeds: torch.Tensor,
audio_attention_mask: torch.Tensor = None, audio_attention_mask: Optional[torch.Tensor] = None,
audio_projection_mode: str = "speech", audio_projection_mode: str = "speech",
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
arguments: arguments:
input_embeds: audio features (B, T, D) B: num audios in a sequence input_embeds: audio features (B, T, D) B: num audios in a sequence
...@@ -1214,10 +1235,10 @@ class AudioEmbedding(nn.Module): ...@@ -1214,10 +1235,10 @@ class AudioEmbedding(nn.Module):
def forward( def forward(
self, self,
audio_features: torch.FloatTensor, audio_features: torch.Tensor,
audio_attention_mask: torch.Tensor = None, audio_attention_mask: Optional[torch.Tensor] = None,
audio_projection_mode: str = "speech", audio_projection_mode: str = "speech",
) -> torch.FloatTensor: ) -> torch.Tensor:
""" """
arguments: arguments:
audio_features: audio features (T, D) audio_features: audio features (T, D)
......
This diff is collapsed.
...@@ -1193,21 +1193,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1193,21 +1193,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
positions: Flattened (concatenated) position ids corresponding to a positions: Flattened (concatenated) position ids corresponding to a
batch. batch. **NOTE**: If mrope is enabled (default setting for
**NOTE**: If mrope is enabled (default setting for Qwen2.5-VL Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`,
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,). otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
second_per_grid_ts: Tensor `(num_videos)` of video time interval (
in seconds) for each grid along the temporal dimension in the
3D position IDs. `None` if no videos are passed.
""" """
if intermediate_tensors is not None: if intermediate_tensors is not None:
......
...@@ -9,7 +9,7 @@ model alternates between state space model layers and attention-based layers. ...@@ -9,7 +9,7 @@ model alternates between state space model layers and attention-based layers.
""" """
from collections.abc import Iterable from collections.abc import Iterable
from itertools import cycle from itertools import cycle
from typing import Optional, Union from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -528,8 +528,6 @@ class Zamba2MambaDecoderLayer(nn.Module): ...@@ -528,8 +528,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states: Input tensor [batch_size, seq_len, hidden_size] hidden_states: Input tensor [batch_size, seq_len, hidden_size]
mamba_cache_params: Parameters for Mamba's state caches mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm) (one for conv, one for ssm)
sequence_idx: Index tensor for identifying sequences in batch
Required for proper chunked processing in prefill
transformer_hidden_states: Optional output from transformer path transformer_hidden_states: Optional output from transformer path
Added to input if provided (used in hybrid architecture) Added to input if provided (used in hybrid architecture)
positions: Optional position IDs (unused in Mamba) positions: Optional position IDs (unused in Mamba)
...@@ -591,8 +589,6 @@ class Zamba2HybridLayer(nn.Module): ...@@ -591,8 +589,6 @@ class Zamba2HybridLayer(nn.Module):
Args: Args:
shared_transformer: Transformer decoder layer for attention pathway shared_transformer: Transformer decoder layer for attention pathway
linear: Linear projection for transformer output before Mamba
mamba: Mamba decoder layer for state space pathway
""" """
super().__init__() super().__init__()
self.block_idx = block_idx self.block_idx = block_idx
...@@ -630,8 +626,6 @@ class Zamba2HybridLayer(nn.Module): ...@@ -630,8 +626,6 @@ class Zamba2HybridLayer(nn.Module):
positions: Position IDs for positional embeddings positions: Position IDs for positional embeddings
mamba_cache_params: Parameters for Mamba's state caches mamba_cache_params: Parameters for Mamba's state caches
(one for conv, one for ssm) (one for conv, one for ssm)
sequence_idx: Indices for identifying sequences in batch,
required for proper chunked processing in prefill
Returns: Returns:
Output tensor combining transformer and Mamba representations Output tensor combining transformer and Mamba representations
...@@ -915,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -915,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
prefix: Optional prefix for parameter names prefix: Optional prefix for parameter names
Raises: Raises:
AssertionError: If prefix caching is enabled (not supported by AssertionError: If prefix caching is enabled
Mamba) (not supported by Mamba)
""" """
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -971,7 +965,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -971,7 +965,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor: **kwargs: Any) -> torch.Tensor:
"""Forward pass through the model. """Forward pass through the model.
Args: Args:
...@@ -1012,9 +1006,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -1012,9 +1006,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str, def copy_inputs_before_cuda_graphs(
torch.Tensor], self, input_buffers: dict[str, torch.Tensor],
**kwargs) -> dict[str, torch.Tensor]: **kwargs: Any) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture. """Copy inputs before CUDA graph capture.
Args: Args:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment