Unverified commit 2f8844ba, authored by Zhuohan Li and committed by GitHub
Browse files

Re-enable the 80 char line width limit (#3305)

parent 4b59f00e
...@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter ...@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
...@@ -50,7 +51,8 @@ class AWQConfig(QuantizationConfig): ...@@ -50,7 +51,8 @@ class AWQConfig(QuantizationConfig):
def get_config_filenames() -> List[str]: def get_config_filenames() -> List[str]:
return [ return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
"quantize_config.json",
] ]
@classmethod @classmethod
......
...@@ -31,8 +31,8 @@ class GPTQConfig(QuantizationConfig): ...@@ -31,8 +31,8 @@ class GPTQConfig(QuantizationConfig):
self.pack_factor = Fraction(32, self.weight_bits) self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]: if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError( raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is supported for " "Currently, only 2/3/4/8-bit weight quantization is "
f"GPTQ, but got {self.weight_bits} bits.") f"supported for GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, " return (f"GPTQConfig(weight_bits={self.weight_bits}, "
...@@ -101,7 +101,8 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -101,7 +101,8 @@ class GPTQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError( raise ValueError(
"The output size is not aligned with the quantized " "The output size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
...@@ -114,7 +115,8 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -114,7 +115,8 @@ class GPTQLinearMethod(LinearMethodBase):
exllama_state = ExllamaState.UNINITIALIZED exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None scale_and_zero_input_dim = None
if input_size != input_size_per_partition and self.quant_config.group_size != -1: if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
# For act-order models, we cannot use Exllama for row parallel layer # For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act: if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED exllama_state = ExllamaState.UNUSED
......
...@@ -5,7 +5,8 @@ from torch.nn.parameter import Parameter ...@@ -5,7 +5,8 @@ from torch.nn.parameter import Parameter
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class MarlinConfig(QuantizationConfig): class MarlinConfig(QuantizationConfig):
...@@ -22,8 +23,9 @@ class MarlinConfig(QuantizationConfig): ...@@ -22,8 +23,9 @@ class MarlinConfig(QuantizationConfig):
self.group_size = group_size self.group_size = group_size
if self.group_size != 128 and self.group_size != -1: if self.group_size != 128 and self.group_size != -1:
raise ValueError( raise ValueError(
"Currently, only group size 128 and -1 (channelwise) is supported for " "Currently, only group size 128 and -1 (channelwise) "
f"Marlin, but got group_size of {self.group_size}") "is supported for Marlin, but got group_size of "
f"{self.group_size}")
# 4 Bits packed into 32 bit datatype. # 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4 self.pack_factor = 32 // 4
...@@ -37,7 +39,8 @@ class MarlinConfig(QuantizationConfig): ...@@ -37,7 +39,8 @@ class MarlinConfig(QuantizationConfig):
# Min in_features dim # Min in_features dim
self.min_k_threads = 128 self.min_k_threads = 128
# Max parallel problems to solve at once (improves large batch performance) # Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16 self.max_parallel = 16
# Permutation length used by the marlin kernels. # Permutation length used by the marlin kernels.
...@@ -102,22 +105,26 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -102,22 +105,26 @@ class MarlinLinearMethod(LinearMethodBase):
# Validate output_size_per_partition # Validate output_size_per_partition
if output_size_per_partition % self.quant_config.min_n_threads != 0: if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError( raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}." f"Weight output_size_per_partition = "
) f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
if output_size_per_partition % self.quant_config.pack_factor != 0: if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError( raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}." f"Weight output_size_per_partition = "
) f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
# Validate input_size_per_partition # Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0: if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError( raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}." f"Weight input_size_per_partition = "
) f"{input_size_per_partition} is not divisible by "
if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0: f"min_k_threads = {self.quant_config.min_k_threads}.")
raise ValueError( if (self.quant_config.group_size != -1 and
f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}." input_size_per_partition % self.quant_config.group_size != 0):
) raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
# Check that we have at least 4 tiles horizontally in the shard # Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // ( num_tiles_per_perm = self.quant_config.perm_len // (
...@@ -149,7 +156,9 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -149,7 +156,9 @@ class MarlinLinearMethod(LinearMethodBase):
) )
# Determine if channelwise or not # Determine if channelwise or not
input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
scales = Parameter( scales = Parameter(
torch.empty( torch.empty(
......
...@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter ...@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import is_hip from vllm.utils import is_hip
......
...@@ -6,7 +6,8 @@ import torch.nn as nn ...@@ -6,7 +6,8 @@ import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather) tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput, SamplerOutput, SequenceData, SequenceGroupOutput,
......
...@@ -333,7 +333,8 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -333,7 +333,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if name == "lm_head.weight": if name == "lm_head.weight":
# Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to: # Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
# Distinguish between Baichuan and Baichuan2 by checking the # Distinguish between Baichuan and Baichuan2 by checking the
# vocab size. This is suggested by # vocab size. This is suggested by
......
...@@ -119,7 +119,8 @@ class DeepseekMoE(nn.Module): ...@@ -119,7 +119,8 @@ class DeepseekMoE(nn.Module):
linear_method=None) linear_method=None)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekMLP( self.shared_experts = DeepseekMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
...@@ -273,8 +274,9 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -273,8 +274,9 @@ class DeepseekDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, linear_method=linear_method,
) )
if (config.n_routed_experts is not None and \ if (config.n_routed_experts is not None
layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, linear_method=linear_method) self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
else: else:
self.mlp = DeepseekMLP( self.mlp = DeepseekMLP(
......
...@@ -143,7 +143,8 @@ class GPTJBlock(nn.Module): ...@@ -143,7 +143,8 @@ class GPTJBlock(nn.Module):
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, linear_method) self.attn = GPTJAttention(config, linear_method)
self.mlp = GPTJMLP(inner_dim, config, linear_method) self.mlp = GPTJMLP(inner_dim, config, linear_method)
......
...@@ -305,7 +305,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -305,7 +305,8 @@ class InternLM2ForCausalLM(nn.Module):
param = params_dict[name] param = params_dict[name]
if "wqkv" in name: if "wqkv" in name:
config = self.config config = self.config
kv_groups = config.num_attention_heads // config.num_key_value_heads kv_groups = (config.num_attention_heads //
config.num_key_value_heads)
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(-1, 2 + kv_groups, loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
head_dim, head_dim,
......
...@@ -52,7 +52,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -52,7 +52,8 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, ) get_tensor_model_parallel_world_size, )
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -81,7 +82,8 @@ class SwiGLU(nn.Module): ...@@ -81,7 +82,8 @@ class SwiGLU(nn.Module):
class OlmoAttention(nn.Module): class OlmoAttention(nn.Module):
""" """
This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` This is the attention block where the output is computed as
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection). (plus another skip connection).
""" """
...@@ -94,11 +96,12 @@ class OlmoAttention(nn.Module): ...@@ -94,11 +96,12 @@ class OlmoAttention(nn.Module):
self.config = config self.config = config
self.hidden_size = config.d_model self.hidden_size = config.d_model
assert config.d_model % config.n_heads == 0 assert config.d_model % config.n_heads == 0
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( tensor_model_parallel_world_size = (
) get_tensor_model_parallel_world_size())
self.total_num_heads = self.config.n_heads self.total_num_heads = self.config.n_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // self.total_num_heads self.head_dim = self.hidden_size // self.total_num_heads
# Layer norms. # Layer norms.
...@@ -158,7 +161,8 @@ class OlmoAttention(nn.Module): ...@@ -158,7 +161,8 @@ class OlmoAttention(nn.Module):
class OlmoMLP(nn.Module): class OlmoMLP(nn.Module):
""" """
This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` This is the MLP block where the output is computed as
``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection). (plus another skip connection).
""" """
...@@ -217,7 +221,8 @@ class OlmoMLP(nn.Module): ...@@ -217,7 +221,8 @@ class OlmoMLP(nn.Module):
class OlmoBlock(nn.Module): class OlmoBlock(nn.Module):
""" """
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection). (plus another skip connection).
""" """
......
...@@ -170,7 +170,8 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -170,7 +170,8 @@ class Qwen2DecoderLayer(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000) rope_theta = getattr(config, "rope_theta", 1000000)
use_sliding_window = config.use_sliding_window and layer_idx < config.max_window_layers use_sliding_window = (config.use_sliding_window
and layer_idx < config.max_window_layers)
self.self_attn = Qwen2Attention( self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
......
# coding=utf-8 # coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved. # Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -16,7 +17,8 @@ ...@@ -16,7 +17,8 @@
# This code is based off the following work: # This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
...@@ -102,8 +104,8 @@ class StablelmAttention(nn.Module): ...@@ -102,8 +104,8 @@ class StablelmAttention(nn.Module):
self.kv_size = self.num_key_value_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim
self.qkv_bias = getattr(config, "use_qkv_bias", False) self.qkv_bias = getattr(config, "use_qkv_bias", False)
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size: if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
raise ValueError( raise ValueError(f"hidden_size must be divisible by num_heads "
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f"(got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).") f" and `num_heads`: {self.num_heads}).")
self.qkv_proj = QKVParallelLinear(self.hidden_size, self.qkv_proj = QKVParallelLinear(self.hidden_size,
...@@ -192,7 +194,6 @@ class StableLMEpochModel(nn.Module): ...@@ -192,7 +194,6 @@ class StableLMEpochModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None: linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__() super().__init__()
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
......
...@@ -34,7 +34,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: ...@@ -34,7 +34,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig, device_config: DeviceConfig, def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module: **kwargs) -> nn.Module:
from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig from transformers_neuronx.config import (NeuronConfig,
ContinuousBatchingConfig)
parallel_config = kwargs.get("parallel_config") parallel_config = kwargs.get("parallel_config")
scheduler_config = kwargs.get("scheduler_config") scheduler_config = kwargs.get("scheduler_config")
......
...@@ -11,7 +11,8 @@ from vllm.model_executor.parallel_utils.parallel_state import ( ...@@ -11,7 +11,8 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group, get_tensor_model_parallel_group,
is_cupy_nccl_enabled_for_all_reduce, is_cupy_nccl_enabled_for_all_reduce,
) )
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce from vllm.model_executor.parallel_utils.custom_all_reduce import (
custom_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
......
...@@ -114,7 +114,8 @@ class SamplingTensors: ...@@ -114,7 +114,8 @@ class SamplingTensors:
do_penalties = True do_penalties = True
if (i < sampling_metadata.num_prompts if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get their logprobs # For tokens in the prompt that we only need to get
# their logprobs
prompt_len = sampling_metadata.prompt_lens[i] prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1) temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1) top_ps += [top_p] * (prompt_len - 1)
......
...@@ -74,8 +74,8 @@ class SamplingParams: ...@@ -74,8 +74,8 @@ class SamplingParams:
stop_token_ids: List of tokens that stop the generation when they are stop_token_ids: List of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens. the stop tokens are special tokens.
include_stop_str_in_output: Whether to include the stop strings in output include_stop_str_in_output: Whether to include the stop strings in
text. Defaults to False. output text. Defaults to False.
ignore_eos: Whether to ignore the EOS token and continue generating ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated. tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence. max_tokens: Maximum number of tokens to generate per output sequence.
......
...@@ -351,7 +351,8 @@ class SequenceGroup: ...@@ -351,7 +351,8 @@ class SequenceGroup:
self.metrics.first_token_time = time self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None: def maybe_set_first_scheduled_time(self, time: float) -> None:
"""Sets the first scheduled time and time in queue for Request level timings.""" """Sets the first scheduled time and time in queue for Request
level timings."""
if self.metrics.first_scheduled_time is None: if self.metrics.first_scheduled_time is None:
self.metrics.first_scheduled_time = time self.metrics.first_scheduled_time = time
self.metrics.time_in_queue = time - self.metrics.arrival_time self.metrics.time_in_queue = time - self.metrics.arrival_time
......
...@@ -5,8 +5,12 @@ import torch ...@@ -5,8 +5,12 @@ import torch
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData) from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores get_all_seq_ids,
split_batch_by_proposal_len)
from vllm.spec_decode.interfaces import (SpeculativeScorer,
SpeculativeProposals,
SpeculativeScores)
SeqId = int SeqId = int
TargetSeqId = int TargetSeqId = int
...@@ -68,7 +72,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -68,7 +72,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list = proposals.proposal_lens.tolist() proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist()
spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens = self._expand_batch( (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list, proposal_token_ids_list=proposal_token_ids_list,
proposal_lens_list=proposal_lens_list, proposal_lens_list=proposal_lens_list,
...@@ -125,7 +130,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -125,7 +130,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
num_scoring_tokens = len(target_seq_group_metadata_list) num_scoring_tokens = len(target_seq_group_metadata_list)
target_seq_group_metadata_list.extend(non_spec_seqs) target_seq_group_metadata_list.extend(non_spec_seqs)
return spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
def _contract_batch(self, original_bs: int, def _contract_batch(self, original_bs: int,
target_sampler_output: List[SamplerOutput], target_sampler_output: List[SamplerOutput],
...@@ -306,10 +312,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -306,10 +312,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# Convert non-speculative output tokens to tensors. # Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens sampler_output.sampled_token_ids = non_spec_sampled_tokens
non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch( non_spec_target_token_ids, non_spec_target_probs = (
[sampler_output]) sampler_output_to_torch([sampler_output]))
return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs return (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs)
def _create_target_seq_id_iterator( def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
......
...@@ -5,7 +5,8 @@ import torch ...@@ -5,7 +5,8 @@ import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch from vllm.spec_decode.util import sampler_output_to_torch
...@@ -247,7 +248,8 @@ class DraftModelTop1Proposer(SpeculativeProposer): ...@@ -247,7 +248,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
""" """
# Split speculative- and non-speculative- sequences. # Split speculative- and non-speculative- sequences.
proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices = self._split_by_max_model_len( (proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices) = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len) seq_group_metadata_list, max_proposal_len)
if nonzero_proposal_len_seqs: if nonzero_proposal_len_seqs:
...@@ -306,7 +308,8 @@ class DraftModelTop1Proposer(SpeculativeProposer): ...@@ -306,7 +308,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
else: else:
proposal_lens.append(0) proposal_lens.append(0)
return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices return (proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices)
def _merge_outputs( def _merge_outputs(
self, self,
...@@ -356,7 +359,8 @@ class DraftModelTop1Proposer(SpeculativeProposer): ...@@ -356,7 +359,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
device=self._device) device=self._device)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = entire_proposal_tokens, entire_proposal_probs proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs)
proposal_lens = torch.zeros(batch_size, proposal_lens = torch.zeros(batch_size,
dtype=torch.long, dtype=torch.long,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment