Unverified Commit f2e9f2a3 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Remove redundant TypeVar from base model (#12248)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 1f1542af
...@@ -3,7 +3,6 @@ from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union, ...@@ -3,7 +3,6 @@ from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig
from typing_extensions import TypeIs, TypeVar from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -19,9 +18,6 @@ if TYPE_CHECKING: ...@@ -19,9 +18,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
# The type of HF config
C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)
# The type of hidden states # The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa # Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor] # which has T = List[torch.Tensor]
...@@ -34,7 +30,7 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) ...@@ -34,7 +30,7 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
@runtime_checkable @runtime_checkable
class VllmModel(Protocol[C_co, T_co]): class VllmModel(Protocol[T_co]):
"""The interface required for all models in vLLM.""" """The interface required for all models in vLLM."""
def __init__( def __init__(
...@@ -97,7 +93,7 @@ def is_vllm_model( ...@@ -97,7 +93,7 @@ def is_vllm_model(
@runtime_checkable @runtime_checkable
class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]): class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
"""The interface required for all generative models in vLLM.""" """The interface required for all generative models in vLLM."""
def compute_logits( def compute_logits(
...@@ -143,7 +139,7 @@ def is_text_generation_model( ...@@ -143,7 +139,7 @@ def is_text_generation_model(
@runtime_checkable @runtime_checkable
class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]): class VllmModelForPooling(VllmModel[T], Protocol[T]):
"""The interface required for all pooling models in vLLM.""" """The interface required for all pooling models in vLLM."""
def pooler( def pooler(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment