Unverified Commit df29793d authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[mypy][5/N] Support all typing on model executor (#4427)

parent 03dd7d52
...@@ -43,8 +43,8 @@ jobs: ...@@ -43,8 +43,8 @@ jobs:
mypy vllm/worker --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
# TODO(sang): Fix nested dir # TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
...@@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml ...@@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml
......
...@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: ...@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
return schema return schema
if isinstance(schema, BaseModel): if isinstance(schema, BaseModel):
return schema.model_json_schema() return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}")
@lru_cache @lru_cache
......
...@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module): ...@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module):
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
if quant_config is None: if quant_config is None:
self.quant_method = UnquantizedLinearMethod() self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else: else:
self.quant_method = quant_config.get_quant_method(self) self.quant_method = quant_config.get_quant_method(self)
...@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase): ...@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size, self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size, [self.output_size], self.input_size,
self.output_size, self.params_dtype) self.output_size, self.params_dtype)
...@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase): ...@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias) output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
return output, output_bias return output, output_bias
...@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase): ...@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase):
self.output_size_per_partition = divide(output_size, tp_size) self.output_size_per_partition = divide(output_size, tp_size)
if output_sizes is None: if output_sizes is None:
output_sizes = [output_size] output_sizes = [output_size]
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.quant_method.create_weights(self,
self.input_size, self.input_size,
[x // tp_size for x in output_sizes], [x // tp_size for x in output_sizes],
...@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias) output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
...@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase): ...@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.quant_method.create_weights(self,
self.input_size_per_partition, self.input_size_per_partition,
[self.output_size], [self.output_size],
...@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase): ...@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase):
input_parallel = splitted_input[tp_rank].contiguous() input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_parallel) output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
......
from typing import Type from typing import Dict, Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
...@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig ...@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"fp8": Fp8Config, "fp8": Fp8Config,
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
import torch import torch
from torch import nn from torch import nn
...@@ -76,8 +76,16 @@ class QuantizationConfig(ABC): ...@@ -76,8 +76,16 @@ class QuantizationConfig(ABC):
"quantization config.") "quantization config.")
@abstractmethod @abstractmethod
def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase: def get_quant_method(
"""Get the quantize method to use for the quantized layer.""" self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
...@@ -52,11 +52,10 @@ class SqueezeLLMConfig(QuantizationConfig): ...@@ -52,11 +52,10 @@ class SqueezeLLMConfig(QuantizationConfig):
return cls(weight_bits) return cls(weight_bits)
def get_quant_method( def get_quant_method(
self, self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self) return SqueezeLLMLinearMethod(self)
return return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
......
...@@ -431,8 +431,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module): ...@@ -431,8 +431,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
torch.full_like(positions, k)).long() torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset) idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions) if long_prompt_offset is not None else positions)
self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to( self.long_short_cos_sin_cache: torch.Tensor = (
idx.device) self.long_short_cos_sin_cache.to(idx.device))
idx = torch.add(idx, offsets) if offsets is not None else idx idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
......
...@@ -13,6 +13,9 @@ from vllm.sampling_params import SamplingType ...@@ -13,6 +13,9 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceGroupOutput, SequenceOutput) SamplerOutput, SequenceGroupOutput, SequenceOutput)
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
class Sampler(nn.Module): class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs. """Samples the next tokens from the model's outputs.
...@@ -155,7 +158,7 @@ def _apply_min_tokens_penalty( ...@@ -155,7 +158,7 @@ def _apply_min_tokens_penalty(
have not been generated yet have not been generated yet
""" """
# list of indices in logits that will be set to -inf # list of indices in logits that will be set to -inf
logits_to_penalize = [] logits_to_penalize: List[Tuple[int, int]] = []
logits_applied = 0 logits_applied = 0
for seq_group in sampling_metadata.seq_groups: for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
...@@ -269,7 +272,7 @@ def _apply_min_p( ...@@ -269,7 +272,7 @@ def _apply_min_p(
def _greedy_sample( def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample], selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor, samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]: ) -> SampleResultType:
"""Run greedy sampling on a given samples. """Run greedy sampling on a given samples.
Args: Args:
...@@ -284,7 +287,7 @@ def _greedy_sample( ...@@ -284,7 +287,7 @@ def _greedy_sample(
""" """
samples = samples.tolist() samples = samples.tolist()
sample_idx = 0 sample_idx = 0
results = [] results: SampleResultType = []
for seq_group in selected_seq_groups: for seq_group in selected_seq_groups:
if not seq_group.do_sample: if not seq_group.do_sample:
results.append(([], [])) results.append(([], []))
...@@ -304,7 +307,7 @@ def _greedy_sample( ...@@ -304,7 +307,7 @@ def _greedy_sample(
def _random_sample( def _random_sample(
selected_seq_groups: List[SequenceGroupToSample], selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor, random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]: ) -> SampleResultType:
"""Run random sampling on a given samples. """Run random sampling on a given samples.
Args: Args:
...@@ -320,7 +323,7 @@ def _random_sample( ...@@ -320,7 +323,7 @@ def _random_sample(
# Find the maximum best_of value of the prompt phase requests. # Find the maximum best_of value of the prompt phase requests.
random_samples = random_samples.cpu() random_samples = random_samples.cpu()
sample_idx = 0 sample_idx = 0
results = [] results: SampleResultType = []
for seq_group in selected_seq_groups: for seq_group in selected_seq_groups:
if not seq_group.do_sample: if not seq_group.do_sample:
results.append(([], [])) results.append(([], []))
...@@ -348,7 +351,7 @@ def _random_sample( ...@@ -348,7 +351,7 @@ def _random_sample(
def _beam_search_sample( def _beam_search_sample(
selected_seq_groups: List[SequenceGroupToSample], selected_seq_groups: List[SequenceGroupToSample],
logprobs: torch.Tensor, logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]: ) -> SampleResultType:
"""Run beam sampling on a given samples. """Run beam sampling on a given samples.
Args: Args:
...@@ -370,7 +373,7 @@ def _beam_search_sample( ...@@ -370,7 +373,7 @@ def _beam_search_sample(
# NOTE: Beam search is not vectorized, so its speed can be slower than # NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods. # other sampling methods.
sample_idx = 0 sample_idx = 0
results = [] results: SampleResultType = []
for seq_group in selected_seq_groups: for seq_group in selected_seq_groups:
if not seq_group.do_sample: if not seq_group.do_sample:
results.append(([], [])) results.append(([], []))
...@@ -391,16 +394,16 @@ def _beam_search_sample( ...@@ -391,16 +394,16 @@ def _beam_search_sample(
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
else: else:
# Generation phase. # Generation phase.
cumulative_logprobs = [ cumulative_logprobs: List[int] = [
seq_group.seq_data[seq_id].cumulative_logprob seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids for seq_id in seq_ids
] ]
cumulative_logprobs = torch.tensor( cumulative_logprobs_tensor = torch.tensor(
cumulative_logprobs, cumulative_logprobs,
dtype=torch.float, dtype=torch.float,
device=seq_group_logprobs.device) device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs + seq_group_logprobs = (seq_group_logprobs +
cumulative_logprobs.unsqueeze(dim=1)) cumulative_logprobs_tensor.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(), _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width) 2 * beam_width)
topk_ids = topk_ids.tolist() topk_ids = topk_ids.tolist()
...@@ -452,8 +455,10 @@ def _sample_with_torch( ...@@ -452,8 +455,10 @@ def _sample_with_torch(
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
include_gpu_probs_tensor: bool, include_gpu_probs_tensor: bool,
modify_greedy_probs: bool, modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
...@@ -555,8 +560,10 @@ def _sample_with_triton_kernel( ...@@ -555,8 +560,10 @@ def _sample_with_triton_kernel(
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors, sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]: ) -> SampleResultType:
categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
...@@ -632,7 +639,7 @@ def _sample( ...@@ -632,7 +639,7 @@ def _sample(
probs: torch.Tensor, logprobs: torch.Tensor, probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
""" """
Args: Args:
probs: (num_query_tokens_in_batch, num_vocab) probs: (num_query_tokens_in_batch, num_vocab)
...@@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: ...@@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def _get_logprobs( def _get_logprobs(
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]], sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs. """Return sample lobprobs and prompt logprobs.
...@@ -751,8 +758,8 @@ def _get_logprobs( ...@@ -751,8 +758,8 @@ def _get_logprobs(
assert len(next_token_ids) == len(query_indices) assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0: if len(query_indices) == 0:
empty_sampled_logprob = [] empty_sampled_logprob: SampleLogprobs = []
empty_prompt_logprob = None empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob] return [empty_prompt_logprob], [empty_sampled_logprob]
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
...@@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, ...@@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output( def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]], sample_results: SampleResultType,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]], prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs], sample_logprobs: List[SampleLogprobs],
...@@ -1009,7 +1016,7 @@ def _build_sampler_output( ...@@ -1009,7 +1016,7 @@ def _build_sampler_output(
) )
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]: def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
"""Get a list of next prompt tokens to compute logprob from a """Get a list of next prompt tokens to compute logprob from a
given sequence group. given sequence group.
......
...@@ -64,7 +64,7 @@ class TensorizerConfig: ...@@ -64,7 +64,7 @@ class TensorizerConfig:
"s3_secret_access_key": self.s3_secret_access_key, "s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint, "s3_endpoint": self.s3_endpoint,
} }
return TensorizerArgs(**tensorizer_args) return TensorizerArgs(**tensorizer_args) # type: ignore
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
...@@ -270,8 +270,10 @@ class TensorizerAgent: ...@@ -270,8 +270,10 @@ class TensorizerAgent:
self.model = self._init_model() self.model = self._init_model()
def _init_model(self): def _init_model(self):
assert self.tensorizer_config.hf_config is not None
model_args = self.tensorizer_config.hf_config model_args = self.tensorizer_config.hf_config
model_args.torch_dtype = self.tensorizer_config.dtype model_args.torch_dtype = self.tensorizer_config.dtype
assert self.tensorizer_config.model_class is not None
with no_init_or_tensor(): with no_init_or_tensor():
return self.tensorizer_config.model_class( return self.tensorizer_config.model_class(
config=model_args, config=model_args,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment