# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/vision.py

from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar, Union

import torch
from transformers import PretrainedConfig

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)

_C = TypeVar("_C", bound=PretrainedConfig)


class VisionEncoderInfo(ABC, Generic[_C]):
    """Static geometry/token metadata for a vision encoder.

    Subclasses wrap a HuggingFace vision config and expose the image/patch
    arithmetic (token counts, image size, patch size, patch-grid length)
    needed by multimodal processors.
    """

    def __init__(self, vision_config: _C) -> None:
        super().__init__()
        self.vision_config = vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size."""
        raise NotImplementedError

    @abstractmethod
    def get_max_image_tokens(self) -> int:
        """Return the maximum number of image tokens this encoder can emit."""
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's expected input image size."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the side length of a single patch, in pixels."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the number of patches along one side of the image."""
        raise NotImplementedError


def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    feature_sample_layers: Optional[list[int]],
    post_layer_norm: Optional[torch.nn.LayerNorm],
    max_possible_layers: int,
) -> torch.Tensor:
    """Given the outputs a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        feature_sample_layers: Optional layer indices to grab from the
            encoder outputs; if provided, encoder outputs must be a list.
        post_layer_norm: Post norm to apply to the output of the encoder.
        max_possible_layers: Total layers in the fully loaded visual encoder.

    Returns:
        A single tensor: either the (optionally normalized) last-layer
        output, or the sampled hidden states concatenated on the last dim.
    """
    if feature_sample_layers is None:
        # Single-tensor path: encoder_outputs is the last layer's output.
        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)
        return encoder_outputs

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
        for layer_idx in feature_sample_layers
    ]

    # Apply post-norm on the final hidden state if we are using it.
    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
    if post_layer_norm is not None and uses_last_layer:
        # Bug fix: normalize the final hidden state tensor, not the whole
        # list of hidden states (LayerNorm cannot consume a list). Matches
        # the upstream vllm implementation this file was adapted from.
        hs_pool[-1] = post_layer_norm(encoder_outputs[-1])
    return torch.cat(hs_pool, dim=-1)