Commit 99b471c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.1

parents 1925d2e9 468d761b
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -30,6 +30,9 @@ from torch import nn ...@@ -30,6 +30,9 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -43,13 +46,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -43,13 +46,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -367,6 +365,8 @@ class Qwen2MoeModel(nn.Module): ...@@ -367,6 +365,8 @@ class Qwen2MoeModel(nn.Module):
class Qwen2MoeForCausalLM(nn.Module): class Qwen2MoeForCausalLM(nn.Module):
fall_back_to_pt_during_load = False
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -405,11 +405,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -405,11 +405,7 @@ class Qwen2MoeForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -420,12 +416,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -420,12 +416,7 @@ class Qwen2MoeForCausalLM(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in weights:
model_name_or_path,
cache_dir,
load_format,
revision,
fall_back_to_pt=False):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
......
...@@ -19,13 +19,14 @@ ...@@ -19,13 +19,14 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights.""" model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -36,11 +37,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -36,11 +37,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -263,11 +261,7 @@ class StablelmForCausalLM(nn.Module): ...@@ -263,11 +261,7 @@ class StablelmForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -277,8 +271,7 @@ class StablelmForCausalLM(nn.Module): ...@@ -277,8 +271,7 @@ class StablelmForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in weights:
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if ("rotary_emb.cos_cached" in name if ("rotary_emb.cos_cached" in name
......
...@@ -18,13 +18,14 @@ ...@@ -18,13 +18,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch Starcoder2 model.""" """ PyTorch Starcoder2 model."""
from typing import List, Optional from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import Starcoder2Config from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
...@@ -35,11 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -35,11 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -275,11 +273,7 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -275,11 +273,7 @@ class Starcoder2ForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -288,8 +282,7 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -288,8 +282,7 @@ class Starcoder2ForCausalLM(nn.Module):
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in weights:
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights.""" """Inference-only Xverse model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -28,6 +28,7 @@ from transformers import PretrainedConfig ...@@ -28,6 +28,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
...@@ -39,11 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -39,11 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -332,11 +330,7 @@ class XverseForCausalLM(nn.Module): ...@@ -332,11 +330,7 @@ class XverseForCausalLM(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [ stacked_params_mapping = [
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
...@@ -345,8 +339,7 @@ class XverseForCausalLM(nn.Module): ...@@ -345,8 +339,7 @@ class XverseForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in weights:
model_name_or_path, cache_dir, load_format, revision):
if ("rotary_emb.inv_freq" in name if ("rotary_emb.inv_freq" in name
or "rotary_emb.cos_cached" in name or "rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name): or "rotary_emb.sin_cached" in name):
......
The files in this folder are ported from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). We keep only the code that is used for inference.
\ No newline at end of file
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence
import torch
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator.

    Args:
        numerator: Value that must be an exact multiple of ``denominator``.
        denominator: Divisor to check against.

    Raises:
        AssertionError: If ``numerator % denominator`` is non-zero.
    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # stripped under ``python -O``, which would turn this check into a no-op.
    # AssertionError is kept so existing callers observe the same exception.
    if numerator % denominator != 0:
        raise AssertionError(
            f"{numerator} is not divisible by {denominator}")
def divide(numerator, denominator):
    """Return ``numerator // denominator`` after a divisibility check.

    The check is delegated to ``ensure_divisibility``, so a numerator that
    is not an exact multiple of the denominator is rejected there before
    any division happens.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor into equal chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of chunks to produce; the size of the last
            dimension is passed through ``divide`` together with this
            value to obtain the per-chunk width.
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.

    Returns:
        A sequence of tensors, one per partition.
    """
    # Index of the trailing axis and the width each chunk gets along it.
    split_axis = tensor.dim() - 1
    chunk_width = divide(tensor.size()[split_axis], num_partitions)

    chunks = torch.split(tensor, chunk_width, dim=split_axis)
    if not contiguous_split_chunks:
        return chunks
    # NOTE: torch.split does not create contiguous tensors by default, so
    # each piece is made contiguous here when the caller asks for it.
    return tuple(piece.contiguous() for piece in chunks)
...@@ -113,6 +113,8 @@ class SamplingTensors: ...@@ -113,6 +113,8 @@ class SamplingTensors:
get_num_triton_sampler_splits(vocab_size)) get_num_triton_sampler_splits(vocab_size))
sample_indices_start_idx = 0 sample_indices_start_idx = 0
assert sampling_metadata.seq_groups is not None
assert sampling_metadata.seq_data is not None
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature temperature = sampling_params.temperature
...@@ -147,6 +149,7 @@ class SamplingTensors: ...@@ -147,6 +149,7 @@ class SamplingTensors:
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get # For tokens in the prompt that we only need to get
# their logprobs # their logprobs
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i] prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1) temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1) top_ps += [top_p] * (prompt_len - 1)
...@@ -172,6 +175,7 @@ class SamplingTensors: ...@@ -172,6 +175,7 @@ class SamplingTensors:
is_prompt = i < sampling_metadata.num_prompts is_prompt = i < sampling_metadata.num_prompts
if is_prompt: if is_prompt:
prompt_best_of.append(sampling_params.best_of) prompt_best_of.append(sampling_params.best_of)
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i] prompt_len = sampling_metadata.prompt_lens[i]
if sampling_params.prompt_logprobs is not None: if sampling_params.prompt_logprobs is not None:
......
...@@ -112,8 +112,10 @@ class RequestOutput: ...@@ -112,8 +112,10 @@ class RequestOutput:
# always has the logprobs of the sampled tokens even if the # always has the logprobs of the sampled tokens even if the
# logprobs are not requested. # logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [ outputs = [
CompletionOutput(seqs.index(seq), seq.output_text, CompletionOutput(seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(), seq.get_output_token_ids(),
seq.get_cumulative_logprob(), seq.get_cumulative_logprob(),
seq.output_logprobs if include_logprobs else None, seq.output_logprobs if include_logprobs else None,
......
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
import copy import copy
from enum import IntEnum from enum import IntEnum
from functools import cached_property from functools import cached_property
from typing import Callable, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from pydantic import Field
from typing_extensions import Annotated
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -88,11 +90,15 @@ class SamplingParams: ...@@ -88,11 +90,15 @@ class SamplingParams:
log probability of the sampled token, so there may be up to log probability of the sampled token, so there may be up to
`logprobs+1` elements in the response. `logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token. prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output. skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True. tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on logits_processors: List of functions that modify logits based on
previously generated tokens. previously generated tokens.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
""" """
def __init__( def __init__(
...@@ -118,9 +124,11 @@ class SamplingParams: ...@@ -118,9 +124,11 @@ class SamplingParams:
min_tokens: int = 0, min_tokens: int = 0,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
detokenize: bool = True,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True, spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None, logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> None: ) -> None:
self.n = n self.n = n
self.best_of = best_of if best_of is not None else n self.best_of = best_of if best_of is not None else n
...@@ -150,10 +158,22 @@ class SamplingParams: ...@@ -150,10 +158,22 @@ class SamplingParams:
self.min_tokens = min_tokens self.min_tokens = min_tokens
self.logprobs = logprobs self.logprobs = logprobs
self.prompt_logprobs = prompt_logprobs self.prompt_logprobs = prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self.detokenize = detokenize
self.skip_special_tokens = skip_special_tokens self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
if self.stop and not include_stop_str_in_output:
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
else:
self.output_text_buffer_length = 0
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
self._verify_beam_search() self._verify_beam_search()
...@@ -210,6 +230,16 @@ class SamplingParams: ...@@ -210,6 +230,16 @@ class SamplingParams:
if self.prompt_logprobs is not None and self.prompt_logprobs < 0: if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got " raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.") f"{self.prompt_logprobs}.")
if (self.truncate_prompt_tokens is not None
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize:
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
def _verify_beam_search(self) -> None: def _verify_beam_search(self) -> None:
if self.best_of == 1: if self.best_of == 1:
...@@ -241,6 +271,18 @@ class SamplingParams: ...@@ -241,6 +271,18 @@ class SamplingParams:
raise ValueError("best_of must be 1 when using greedy sampling." raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.") f"Got {self.best_of}.")
def update_from_generation_config(
        self, generation_config: Dict[str, Any]) -> None:
    """Update if there are non-default values from generation_config.

    Only ``eos_token_id`` is consulted: every id it provides is added to
    ``self.stop_token_ids`` (duplicates are dropped).

    Args:
        generation_config: Mapping of generation settings. Its
            ``eos_token_id`` entry may be a single int or a list of ints;
            a missing entry, ``None`` or an empty list leaves
            ``self.stop_token_ids`` untouched.
    """
    eos_ids = generation_config.get("eos_token_id")
    # Compare against None instead of relying on truthiness: token id 0 is
    # falsy but is a perfectly valid EOS id and must not be ignored.
    if eos_ids is None:
        return
    # It can be either int or list of int.
    if isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    if not eos_ids:
        return
    # dict.fromkeys removes duplicates while keeping the existing stop ids
    # first and in their original order, so the result is deterministic.
    self.stop_token_ids = list(
        dict.fromkeys([*self.stop_token_ids, *eos_ids]))
@cached_property @cached_property
def sampling_type(self) -> SamplingType: def sampling_type(self) -> SamplingType:
if self.use_beam_search: if self.use_beam_search:
...@@ -290,4 +332,5 @@ class SamplingParams: ...@@ -290,4 +332,5 @@ class SamplingParams:
f"prompt_logprobs={self.prompt_logprobs}, " f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, " f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens=" "spaces_between_special_tokens="
f"{self.spaces_between_special_tokens})") f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
...@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum): ...@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum):
return finish_reason return finish_reason
class SequenceStage(enum.Enum):
    """Processing stage of a sequence: prefill or decode."""

    # Stage in which tokens of the sequence still have to be computed.
    PREFILL = enum.auto()
    # Stage entered once no uncomputed tokens remain.
    DECODE = enum.auto()
@dataclass @dataclass
class RequestMetrics: class RequestMetrics:
"""Metrics associated with a request. """Metrics associated with a request.
...@@ -115,6 +120,7 @@ class SequenceData: ...@@ -115,6 +120,7 @@ class SequenceData:
self.cumulative_logprob = 0.0 self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model). # The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0 self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL
def append_token_id(self, token_id: int, logprob: float) -> None: def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id) self.output_token_ids.append(token_id)
...@@ -136,19 +142,25 @@ class SequenceData: ...@@ -136,19 +142,25 @@ class SequenceData:
"""Return the number of prefill tokens that are already computed.""" """Return the number of prefill tokens that are already computed."""
return self._num_computed_tokens return self._num_computed_tokens
def update_num_computed_tokens(self, num_new_computed_tokens: int) -> int: def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far.""" """Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
def reset_num_computed_tokens(self) -> None: def reset_state_for_recompute(self) -> None:
"""Reset the number of computed tokens from this sequence. It is """Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from supposed to be called when a sequence needs to be started from
the beginning again (e.g., sequence is preempted). the beginning again (e.g., sequence is preempted).
""" """
self._num_computed_tokens = 0 self._num_computed_tokens = 0
self._stage = SequenceStage.PREFILL
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
"""Return the number of prefil tokens that are not computed.""" """Return the number of prefill tokens that are not computed."""
# we use `get_len()` which includes prompt_len + output_len instead # we use `get_len()` which includes prompt_len + output_len instead
# of prompt_len here. This is because during recompute we need to # of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output. # prefill for both prompt and output.
...@@ -159,12 +171,16 @@ class SequenceData: ...@@ -159,12 +171,16 @@ class SequenceData:
return self.prompt_token_ids[-1] return self.prompt_token_ids[-1]
return self.output_token_ids[-1] return self.output_token_ids[-1]
def get_prompt_token_ids(self) -> int: def get_prompt_token_ids(self) -> List[int]:
return self.prompt_token_ids return self.prompt_token_ids
def get_output_token_ids(self) -> int: def get_output_token_ids(self) -> List[int]:
return self.output_token_ids return self.output_token_ids
@property
def stage(self) -> SequenceStage:
    """Read-only view of the stage currently stored in ``self._stage``."""
    return self._stage
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceData(" return (f"SequenceData("
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
...@@ -219,6 +235,12 @@ class Sequence: ...@@ -219,6 +235,12 @@ class Sequence:
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
def get_output_text_to_return(self, buffer_length: int):
    """Return the output text, holding back a trailing buffer if needed.

    While the sequence is unfinished, the last ``buffer_length`` characters
    of ``self.output_text`` are withheld. Once the sequence is finished, or
    when ``buffer_length`` is zero, the full text is returned.
    """
    # Nothing to hold back, or the sequence is done: return everything.
    if not buffer_length or self.is_finished():
        return self.output_text
    return self.output_text[:-buffer_length]
def hash_of_block(self, logical_idx: int) -> int: def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size # TODO This can produce incorrect hash when block size > prompt size
...@@ -234,7 +256,7 @@ class Sequence: ...@@ -234,7 +256,7 @@ class Sequence:
def reset_state_for_recompute(self): def reset_state_for_recompute(self):
"""Reset the sequence states for recomputation.""" """Reset the sequence states for recomputation."""
self.data.reset_num_computed_tokens() self.data.reset_state_for_recompute()
def _append_logical_block(self) -> None: def _append_logical_block(self) -> None:
block = LogicalTokenBlock( block = LogicalTokenBlock(
...@@ -320,6 +342,20 @@ class Sequence: ...@@ -320,6 +342,20 @@ class Sequence:
new_seq.seq_id = new_seq_id new_seq.seq_id = new_seq_id
return new_seq return new_seq
def get_num_new_tokens(self) -> int:
    """Get the number of new tokens to be computed.

    Returns:
        1 when the sequence data is in the decode stage; otherwise the
        number of tokens that have not been computed yet, as reported by
        ``self.data.get_num_uncomputed_tokens()``.
    """
    seq_data = self.data
    in_decode = seq_data.stage == SequenceStage.DECODE
    return 1 if in_decode else seq_data.get_num_uncomputed_tokens()
def is_prefill(self) -> bool:
    """Return True while this sequence's data is in the prefill stage."""
    current_stage = self.data.stage
    return current_stage == SequenceStage.PREFILL
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"Sequence(seq_id={self.seq_id}, " return (f"Sequence(seq_id={self.seq_id}, "
f"status={self.status.name}, " f"status={self.status.name}, "
...@@ -331,7 +367,7 @@ class SequenceGroupState: ...@@ -331,7 +367,7 @@ class SequenceGroupState:
"""Mutable state tied to a specific sequence group""" """Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling # torch.Generator used in seeded sampling
generator: Optional = None generator: Optional = None # type: ignore
class MultiModalData: class MultiModalData:
...@@ -461,16 +497,22 @@ class SequenceGroup: ...@@ -461,16 +497,22 @@ class SequenceGroup:
def update_num_computed_tokens(self, num_new_computed_tokens: int): def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far.""" """Update number of tokens computed so far."""
for seq in self.seqs_dict.values(): for seq in self.seqs_dict.values():
seq.data.update_num_computed_tokens(num_new_computed_tokens) if not seq.is_finished():
seq.data.update_num_computed_tokens(num_new_computed_tokens)
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
# All sequences in the group should have the same prompt, so the num_uncomputed_tokens = 0
# number of unfinished prefill tokens are the same across all for seq in self.get_seqs():
# sequences. if not seq.is_finished():
return list( num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
self.seqs_dict.values())[0].data.get_num_uncomputed_tokens() return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
# Optimization. We don't need to call get_seqs if we don't need to
# filter by states.
if status is None:
return len(self.seqs_dict)
return len(self.get_seqs(status)) return len(self.get_seqs(status))
def num_unfinished_seqs(self) -> int: def num_unfinished_seqs(self) -> int:
...@@ -497,6 +539,10 @@ class SequenceGroup: ...@@ -497,6 +539,10 @@ class SequenceGroup:
def is_finished(self) -> bool: def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.get_seqs()) return all(seq.is_finished() for seq in self.get_seqs())
def is_prefill(self) -> bool:
    """Return whether this sequence group is in the prefill stage.

    Only the first sequence returned by ``get_seqs()`` is inspected.
    """
    # Every sequence should be in the same stage, so the first one is
    # representative of the whole group.
    # NOTE(review): indexing [0] presumably fails if the group has no
    # sequences -- confirm callers never hit that case.
    return self.get_seqs()[0].is_prefill()
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, " return (f"SequenceGroup(request_id={self.request_id}, "
f"sampling_params={self.sampling_params}, " f"sampling_params={self.sampling_params}, "
...@@ -513,8 +559,8 @@ class SequenceGroupMetadata: ...@@ -513,8 +559,8 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block block_tables: The block tables. (Seq id -> list of physical block
numbers) numbers)
token_chunk_size: The number of tokens to be processed. None if token_chunk_size: The number of tokens to be processed (per sequence).
chunking is not required. None if chunking is not required.
state: Internal state tied to this sequence group. state: Internal state tied to this sequence group.
lora_request: LoRA request. lora_request: LoRA request.
multi_modal_data: Multi modal data. multi_modal_data: Multi modal data.
...@@ -555,7 +601,7 @@ class SequenceGroupMetadata: ...@@ -555,7 +601,7 @@ class SequenceGroupMetadata:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
@property @property
def token_chunk_size(self) -> int: def token_chunk_size(self) -> Optional[int]:
"""Return the number of tokens to be processed (chunk size).""" """Return the number of tokens to be processed (chunk size)."""
return self._token_chunk_size return self._token_chunk_size
...@@ -649,3 +695,16 @@ class SamplerOutput: ...@@ -649,3 +695,16 @@ class SamplerOutput:
def __eq__(self, other: object): def __eq__(self, other: object):
return isinstance(other, return isinstance(other,
self.__class__) and self.outputs == other.outputs self.__class__) and self.outputs == other.outputs
def __repr__(self) -> str:
    """Show the shape of a tensor instead of its values to reduce noise.

    ``sampled_token_probs`` and ``sampled_token_ids`` are rendered as the
    string "None" when unset, otherwise as their ``.shape``; ``outputs`` and
    ``spec_decode_worker_metrics`` are formatted as-is.
    """
    sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                else self.sampled_token_probs.shape)
    sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
                              self.sampled_token_ids.shape)
    return (
        f"SamplerOutput(outputs={self.outputs}, "
        f"sampled_token_probs={sampled_token_probs_repr}, "
        f"sampled_token_ids={sampled_token_ids_repr}, "
        f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
...@@ -9,7 +9,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, ...@@ -9,7 +9,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
sampler_output_to_torch, sampler_output_to_torch,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.worker.worker import Worker from vllm.worker.worker_base import WorkerBase
SeqId = int SeqId = int
TargetSeqId = int TargetSeqId = int
...@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree. of topk/tree.
""" """
def __init__(self, scorer_worker: Worker, device: str, vocab_size: int): def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
self._scorer_worker = scorer_worker self._scorer_worker = scorer_worker
self._device = device self._device = device
self._vocab_size = vocab_size self._vocab_size = vocab_size
...@@ -71,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -71,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list = proposals.proposal_lens.tolist() proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist()
# Filter the list to ignore -1 proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if -1 not in proposals
]
(spec_indices, non_spec_indices, target_seq_group_metadata_list, (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch( num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list, proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list, proposal_lens_list=proposal_lens_list,
) )
...@@ -83,10 +90,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -83,10 +90,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
return_python_output=False) )
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch( all_tokens, all_probs = self._contract_batch(
original_bs=len(seq_group_metadata_list), contracted_bs=len(seq_group_metadata_list),
target_sampler_output=target_sampler_output, target_sampler_output=target_sampler_output,
proposals=proposals, proposals=proposals,
num_scoring_tokens=num_scoring_tokens, num_scoring_tokens=num_scoring_tokens,
...@@ -103,7 +112,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -103,7 +112,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _expand_batch( def _expand_batch(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids_list: List[TokenId], proposal_token_ids_list: List[List[TokenId]],
proposal_lens_list: List[int], proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
"""Given the input sequences and potentially multiple corresponding """Given the input sequences and potentially multiple corresponding
...@@ -125,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -125,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
select_proposal_len_zero=True) select_proposal_len_zero=True)
target_seq_group_metadata_list = self._create_scoring_model_input( target_seq_group_metadata_list = self._create_scoring_model_input(
spec_seqs, proposal_token_ids_list) seq_group_metadata_list=spec_seqs,
proposal_token_ids=proposal_token_ids_list,
# NOTE: We determine the seq ids in the expanded batch using the
# full seq_group_metadata_list, instead of only spec_seqs.
target_seq_ids_iter=self._create_target_seq_id_iterator(
seq_ids=get_all_seq_ids(seq_group_metadata_list)),
)
num_scoring_tokens = len(target_seq_group_metadata_list) num_scoring_tokens = len(target_seq_group_metadata_list)
target_seq_group_metadata_list.extend(non_spec_seqs) target_seq_group_metadata_list.extend(non_spec_seqs)
return (spec_indices, non_spec_indices, target_seq_group_metadata_list, return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) num_scoring_tokens)
def _contract_batch(self, original_bs: int, def _contract_batch(self, contracted_bs: int,
target_sampler_output: List[SamplerOutput], target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int], num_scoring_tokens: int, non_spec_indices: List[int],
...@@ -141,6 +157,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -141,6 +157,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Contract the expanded batch back into its original size. """Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original This maps the scores of speculative tokens back to their original
sequences. sequences.
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
""" """
(target_token_ids, target_probs, non_spec_target_token_ids, (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output( non_spec_target_probs) = self._split_scoring_output(
...@@ -148,25 +167,31 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -148,25 +167,31 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# Map distinct sequences used to score each token # Map distinct sequences used to score each token
# of shape [batch_size * (k + 1)] back to [batch_size, k + 1]. # of shape [batch_size * (k + 1)] back to [batch_size, k + 1].
batch_size, k = proposals.proposal_token_ids.shape expanded_batch_size, k = proposals.proposal_token_ids.shape
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.squeeze().reshape( target_token_ids = target_token_ids.squeeze().reshape(
batch_size, k + 1) spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(batch_size, k + 1, target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size) self._vocab_size)
all_tokens = torch.full(size=(original_bs, k + 1), all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1, fill_value=-1,
device=self._device, device=self._device,
dtype=torch.long) dtype=torch.long)
all_probs = torch.zeros(original_bs, all_probs = torch.zeros(contracted_bs,
k + 1, k + 1,
self._vocab_size, self._vocab_size,
device=self._device, device=self._device,
dtype=torch.float32) dtype=torch.float32)
if non_spec_indices: if non_spec_indices:
all_tokens[non_spec_indices, 0] = non_spec_target_token_ids all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_probs[non_spec_indices, :1, :] = non_spec_target_probs
if spec_indices: if spec_indices:
...@@ -176,20 +201,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -176,20 +201,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return all_tokens, all_probs return all_tokens, all_probs
def _create_scoring_model_input( def _create_scoring_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]: ) -> List[SequenceGroupMetadata]:
"""Given the original input sequences and proposed tokens from the draft """Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring. model, create a list of target sequences that can be used for scoring.
target_seq_ids_iter provides sequence ids for the expanded batch,
fulfilling the requirement that no seq id in the expanded batch is equal
to the seq id in the original batch.
""" """
if not seq_group_metadata_list: if not seq_group_metadata_list:
return [] return []
target_seq_ids_iter = self._create_target_seq_id_iterator(
get_all_seq_ids(seq_group_metadata_list))
target_seq_group_metadata = list( target_seq_group_metadata = list(
chain.from_iterable( chain.from_iterable(
self._create_target_seq_group_metadata( self._create_target_seq_group_metadata(
...@@ -205,7 +232,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -205,7 +232,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _create_target_seq_group_metadata( def _create_target_seq_group_metadata(
self, self,
input_seq_group_metadata: SequenceGroupMetadata, input_seq_group_metadata: SequenceGroupMetadata,
proposal_token_ids: List[TokenId], # shape: [batch_size, k] proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
batch_index: int, batch_index: int,
target_seq_ids_iter: Iterator[TargetSeqId], target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]: ) -> List[SequenceGroupMetadata]:
...@@ -347,7 +374,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -347,7 +374,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
[0, 1, 2] [0, 1, 2]
[0, 1, 2, 3] [0, 1, 2, 3]
""" """
empty_token_ids = [] empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids] token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([ token_ids_to_score.extend([
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional
import torch import torch
...@@ -24,9 +24,9 @@ class SpeculativeProposals: ...@@ -24,9 +24,9 @@ class SpeculativeProposals:
def __repr__(self): def __repr__(self):
return (f"SpeculativeProposals(" return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids.shape}, " f"proposal_token_ids={self.proposal_token_ids}, "
f"proposal_probs={self.proposal_probs.shape}, " f"proposal_probs={self.proposal_probs.shape}, "
f"proposal_lens={self.proposal_lens.shape})") f"proposal_lens={self.proposal_lens})")
@dataclass @dataclass
...@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC): ...@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
blocks_to_copy: Optional[Dict[int, List[int]]], blocks_to_copy: Optional[Dict[int, List[int]]],
k: int, k: int,
proposals: SpeculativeProposals, proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> SpeculativeScores:
raise NotImplementedError raise NotImplementedError
...@@ -112,6 +112,7 @@ class AsyncMetricsCollector: ...@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
Returns a CUDA event recording when the copy is complete. Returns a CUDA event recording when the copy is complete.
""" """
assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream()) self._copy_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._copy_stream): with torch.cuda.stream(self._copy_stream):
...@@ -146,15 +147,16 @@ class AsyncMetricsCollector: ...@@ -146,15 +147,16 @@ class AsyncMetricsCollector:
emitted_tokens = self._aggregate_num_emitted_tokens.item() emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens draft_tokens = self._aggregate_num_draft_tokens
num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k) max_num_emitted_tokens = self.get_max_num_emitted_tokens(
draft_tokens, k)
if draft_tokens > 0: if draft_tokens > 0:
draft_acceptance_rate = accepted_tokens / draft_tokens draft_acceptance_rate = accepted_tokens / draft_tokens
else: else:
draft_acceptance_rate = float("nan") draft_acceptance_rate = float("nan")
if num_possible_tokens > 0: if max_num_emitted_tokens > 0:
system_efficiency = emitted_tokens / num_possible_tokens system_efficiency = emitted_tokens / max_num_emitted_tokens
else: else:
system_efficiency = float("nan") system_efficiency = float("nan")
...@@ -168,8 +170,22 @@ class AsyncMetricsCollector: ...@@ -168,8 +170,22 @@ class AsyncMetricsCollector:
) )
@staticmethod @staticmethod
def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int: def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
# Divide by k since batch size can be variable. """Calculate the number of emitted tokens, assuming all tokens are
total_num_spec_seqs = draft_tokens / k accepted.
num_accepted_per_seq_if_all_accepted = k + 1
return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted) This is equal to the number of sequences that have been speculated on,
times (speculation len + 1). The +1 comes from the bonus token.
"""
# Determine the number of sequences that have been speculated on. Since
# the batch size can be variable, we divide by k.
assert draft_tokens % k == 0
total_num_spec_seqs = draft_tokens // k
# A single sequence may emit k accepted tokens and one bonus token in
# the best case.
num_emitted_per_seq_if_all_accepted = k + 1
# The max num of emitted tokens is the number of speculated sequences
# times the max emitted per seq.
return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
...@@ -25,7 +25,8 @@ class MultiStepWorker(Worker): ...@@ -25,7 +25,8 @@ class MultiStepWorker(Worker):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._proposer: Optional[DraftModelTop1Proposer] = None # Lazy initialization list.
self._proposer: DraftModelTop1Proposer
def init_device(self): def init_device(self):
super().init_device() super().init_device()
...@@ -69,6 +70,9 @@ class MultiStepWorker(Worker): ...@@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
) )
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._append_new_tokens(model_output, self._append_new_tokens(model_output,
copied_seq_group_metadata_list) copied_seq_group_metadata_list)
...@@ -324,23 +328,25 @@ class DraftModelTop1Proposer(SpeculativeProposer): ...@@ -324,23 +328,25 @@ class DraftModelTop1Proposer(SpeculativeProposer):
""" """
if maybe_sampler_output is None: if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None. # If no speculative tokens, the sampler output will be None.
# In this case we return empty tensors. # In this case we return empty proposals.
proposal_tokens = torch.zeros(0, proposal_tokens = torch.full(size=(
max_proposal_len, batch_size,
dtype=torch.long, max_proposal_len,
device=self._device) ),
proposal_probs = torch.zeros(0, fill_value=-1,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(batch_size,
max_proposal_len, max_proposal_len,
self._vocab_size, self._vocab_size,
dtype=torch.float32, dtype=torch.float32,
device=self._device) device=self._device)
proposal_lens = torch.zeros(len(proposal_lens), proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long, dtype=torch.long,
device=self._device) device=self._device)
return proposal_tokens, proposal_probs, proposal_lens return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs = sampler_output_to_torch( proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output) sampler_output)
...@@ -362,9 +368,9 @@ class DraftModelTop1Proposer(SpeculativeProposer): ...@@ -362,9 +368,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
proposal_tokens, proposal_probs = (entire_proposal_tokens, proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs) entire_proposal_probs)
proposal_lens = torch.zeros(batch_size, proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long, dtype=torch.long,
device=self._device) device=self._device)
proposal_lens[nonzero_proposal_len_indices] = max_proposal_len proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
return proposal_tokens, proposal_probs, proposal_lens return proposal_tokens, proposal_probs, proposal_lens_tensor
...@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Tuple ...@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm.config import CacheConfig from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput) SequenceGroupOutput, SequenceOutput)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
...@@ -14,10 +14,12 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector ...@@ -14,10 +14,12 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
class SpecDecodeWorker:
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""Worker which implements speculative decoding. """Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal Speculative decoding reduces decoding per-token latency by using a proposal
...@@ -45,10 +47,20 @@ class SpecDecodeWorker: ...@@ -45,10 +47,20 @@ class SpecDecodeWorker:
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
""" """
@classmethod
def from_workers(cls, proposer_worker: MultiStepWorker,
                 scorer_worker: WorkerBase) -> "SpecDecodeWorker":
    """Build a SpecDecodeWorker from a proposer worker and a scorer worker.

    A RejectionSampler created with ``strict_mode=True`` is passed as the
    rejection sampler; no metrics collector is supplied.
    """
    return SpecDecodeWorker(
        proposer_worker,
        scorer_worker,
        # TODO(cade) disable strict mode for speedup.
        rejection_sampler=RejectionSampler(strict_mode=True),
    )
def __init__( def __init__(
self, self,
proposer_worker: MultiStepWorker, proposer_worker: MultiStepWorker,
scorer_worker: Worker, scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler, rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
): ):
...@@ -77,7 +89,8 @@ class SpecDecodeWorker: ...@@ -77,7 +89,8 @@ class SpecDecodeWorker:
self.probs_dtype = self.rejection_sampler.probs_dtype self.probs_dtype = self.rejection_sampler.probs_dtype
self.token_id_dtype = self.rejection_sampler.token_id_dtype self.token_id_dtype = self.rejection_sampler.token_id_dtype
self.scorer: SpeculativeScorer = None # Lazy initiazliation.
self.scorer: SpeculativeScorer
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize both scorer and proposer models. """Initialize both scorer and proposer models.
...@@ -87,6 +100,10 @@ class SpecDecodeWorker: ...@@ -87,6 +100,10 @@ class SpecDecodeWorker:
self.scorer_worker.init_device() self.scorer_worker.init_device()
self.proposer_worker.init_device() self.proposer_worker.init_device()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self.scorer_worker.load_model()
self.proposer_worker.load_model()
self._metrics.init_gpu_tensors(self.rank) self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank) self.rejection_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer( self.scorer = BatchExpansionTop1Scorer(
...@@ -94,10 +111,33 @@ class SpecDecodeWorker: ...@@ -94,10 +111,33 @@ class SpecDecodeWorker:
device=self.device, device=self.device,
vocab_size=self._vocab_size) vocab_size=self._vocab_size)
def profile_num_available_blocks(self, block_size: int, self._configure_model_sampler_for_spec_decode()
gpu_memory_utilization: float,
cpu_swap_space: int, def _configure_model_sampler_for_spec_decode(self):
cache_dtype: str) -> Tuple[int, int]: """Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of rejection sampling.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
done outside of the model/sampler; this way the "last-mile" worker
object which interfaces with the scheduler can serialize and incur the
performance hit as necessary. This allows us to run the worker several
iterations in a row without incurring the "move to CPU and serialize"
performance penalty.
Since this requires a large change to vLLM, we defer it to later and
temporarily accept this broken abstraction boundary.
NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True
(self.proposer_worker.model_runner.model.sampler.
include_gpu_probs_tensor) = True
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use. """Determine the number of cache blocks to use.
This is done by profiling the scorer model (which is typically the This is done by profiling the scorer model (which is typically the
...@@ -106,27 +146,26 @@ class SpecDecodeWorker: ...@@ -106,27 +146,26 @@ class SpecDecodeWorker:
such that the number of blocks is equal in both KV caches. such that the number of blocks is equal in both KV caches.
""" """
num_gpu_blocks, num_cpu_blocks = ( num_gpu_blocks, num_cpu_blocks = (
self.scorer_worker.profile_num_available_blocks( self.scorer_worker.determine_num_available_blocks())
block_size, gpu_memory_utilization, cpu_swap_space,
cache_dtype))
scorer_cache_block_size_bytes = ( scorer_cache_block_size_bytes = (
self.scorer_worker.get_cache_block_size_bytes( self.scorer_worker.get_cache_block_size_bytes())
block_size, cache_dtype))
proposer_cache_block_size_bytes = ( proposer_cache_block_size_bytes = (
self.proposer_worker.get_cache_block_size_bytes( self.proposer_worker.get_cache_block_size_bytes())
block_size, cache_dtype))
new_num_gpu_blocks = split_num_cache_blocks_evenly( new_num_gpu_blocks = split_num_cache_blocks_evenly(
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
num_gpu_blocks) num_gpu_blocks)
return new_num_gpu_blocks, num_cpu_blocks return new_num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig): def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the cache engine of the scorer and proposer workers. """Initialize the cache engine of the scorer and proposer workers.
""" """
self.scorer_worker.init_cache_engine(cache_config) self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
self.proposer_worker.init_cache_engine(cache_config) num_cpu_blocks=num_cpu_blocks)
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
...@@ -135,7 +174,7 @@ class SpecDecodeWorker: ...@@ -135,7 +174,7 @@ class SpecDecodeWorker:
blocks_to_swap_in: Optional[Dict[int, int]], blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]], blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]], blocks_to_copy: Optional[Dict[int, List[int]]],
num_spec_tokens: int, num_lookahead_slots: int,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch. """Perform speculative decoding on the input batch.
""" """
...@@ -144,9 +183,11 @@ class SpecDecodeWorker: ...@@ -144,9 +183,11 @@ class SpecDecodeWorker:
"speculative decoding " "speculative decoding "
"requires non-None seq_group_metadata_list") "requires non-None seq_group_metadata_list")
logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
# If no spec tokens, call the proposer and scorer workers normally. # If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill. # Used for prefill.
if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0: if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
return self._run_no_spec( return self._run_no_spec(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
...@@ -159,7 +200,7 @@ class SpecDecodeWorker: ...@@ -159,7 +200,7 @@ class SpecDecodeWorker:
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
k=num_spec_tokens, k=num_lookahead_slots,
) )
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
...@@ -174,20 +215,24 @@ class SpecDecodeWorker: ...@@ -174,20 +215,24 @@ class SpecDecodeWorker:
proposer and scorer model so that the KV cache is consistent between the proposer and scorer model so that the KV cache is consistent between the
two. two.
""" """
logger.info("run proposer worker no spec")
self.proposer_worker.execute_model( self.proposer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
return_python_output=False) )
logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model( sampler_output = self.scorer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
) )
assert len(sampler_output) == 1
sampler_output = sampler_output[0]
# Clear device tensors from sampler output. This reduces communication # Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers. # overhead when the engine runs in a different process than the workers.
...@@ -213,11 +258,16 @@ class SpecDecodeWorker: ...@@ -213,11 +258,16 @@ class SpecDecodeWorker:
sequence. sequence.
""" """
logger.info("get spec proposals")
# Generate proposals using draft worker. # Generate proposals using draft worker.
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
proposals = self.proposer_worker.get_spec_proposals( proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k) blocks_to_copy, k)
logger.info("score proposals")
proposal_scores = self.scorer.score_proposals( proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list, seq_group_metadata_list,
blocks_to_swap_in, blocks_to_swap_in,
...@@ -227,9 +277,11 @@ class SpecDecodeWorker: ...@@ -227,9 +277,11 @@ class SpecDecodeWorker:
proposals, proposals,
) )
logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list, accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k) proposal_scores, proposals, k)
logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list, return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k) accepted_token_ids, k)
...@@ -260,15 +312,26 @@ class SpecDecodeWorker: ...@@ -260,15 +312,26 @@ class SpecDecodeWorker:
select_proposal_len_zero=True) select_proposal_len_zero=True)
original_indices = spec_indices + non_spec_indices original_indices = spec_indices + non_spec_indices
proposal_probs = proposal_scores.probs[spec_indices, :-1] # Get probabilities of target model, excluding bonus token.
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
# Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
# Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs[spec_indices]
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
accepted_token_ids = self.rejection_sampler( accepted_token_ids = self.rejection_sampler(
proposal_probs, target_probs=proposal_verifier_probs,
bonus_token_ids, bonus_token_ids=bonus_token_ids,
proposals.proposal_probs, draft_probs=proposal_probs,
proposals.proposal_token_ids, draft_token_ids=proposal_token_ids,
) )
# Append output tokens from non-speculative sequences to # Append output tokens from non-speculative sequences to
...@@ -315,7 +378,7 @@ class SpecDecodeWorker: ...@@ -315,7 +378,7 @@ class SpecDecodeWorker:
parent_seq_id=seq_id, parent_seq_id=seq_id,
output_token=token_id, output_token=token_id,
# TODO Add verifier logprobs. # TODO Add verifier logprobs.
logprobs={token_id: 0.0}, logprobs={token_id: Logprob(0.0)},
) )
], ],
prompt_logprobs=None, prompt_logprobs=None,
...@@ -351,6 +414,16 @@ class SpecDecodeWorker: ...@@ -351,6 +414,16 @@ class SpecDecodeWorker:
def device(self): def device(self):
return self.scorer_worker.device return self.scorer_worker.device
def get_cache_block_size_bytes(self):
    """Report the size, in bytes, of a single cache block.

    Not implemented here: this query exists only so that workers can be
    composed within a SpecDecodeWorker. Composing a SpecDecodeWorker
    inside another SpecDecodeWorker is left undefined for now, although
    it could be implemented in the future.
    See https://arxiv.org/abs/2308.04623.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
proposer_cache_block_size_bytes: int, proposer_cache_block_size_bytes: int,
......
...@@ -82,6 +82,32 @@ def sampler_output_to_torch( ...@@ -82,6 +82,32 @@ def sampler_output_to_torch(
return sampled_token_ids, sampled_token_probs return sampled_token_ids, sampled_token_probs
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
                              vocab_size: int, device: str) -> None:
    """Fill in placeholder device tensors on a SamplerOutput, in place.

    If ``sampled_token_probs`` and ``sampled_token_ids`` are both already
    set, the output is left untouched. Otherwise both attributes are
    overwritten with dummy values: a ``(batch_size, vocab_size)`` float32
    probability tensor and a ``(batch_size,)`` int64 token-id tensor, both
    allocated on ``device``.

    This helper is temporary and will be removed in PR 7/9.
    https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
    """
    probs_missing = sampler_output.sampled_token_probs is None
    ids_missing = sampler_output.sampled_token_ids is None

    # The two tensors must be either both present or both absent.
    assert probs_missing == ids_missing

    if not (probs_missing or ids_missing):
        # Tensors already exist (usually in unit tests); nothing to mock.
        return

    # Random values pushed through a softmax so every row is a valid
    # probability distribution.
    random_scores = torch.rand(batch_size,
                               vocab_size,
                               dtype=torch.float32,
                               device=device)
    sampler_output.sampled_token_probs = torch.nn.functional.softmax(
        random_scores, dim=-1)

    # Dummy token ids drawn from the fixed range [10, 100).
    mock_ids_shape = (batch_size, )
    sampler_output.sampled_token_ids = torch.randint(low=10,
                                                     high=100,
                                                     size=mock_ids_shape,
                                                     dtype=torch.long,
                                                     device=device)
@contextmanager @contextmanager
def nvtx_range(msg, *args, **kwargs): def nvtx_range(msg, *args, **kwargs):
""" """
......
import ray import ray
from vllm.config import ParallelConfig from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.utils import get_open_port from vllm.utils import get_open_port
from vllm.worker.worker import init_distributed_environment
def init_test_distributed_environment( def init_test_distributed_environment(
...@@ -12,15 +12,14 @@ def init_test_distributed_environment( ...@@ -12,15 +12,14 @@ def init_test_distributed_environment(
distributed_init_port: str, distributed_init_port: str,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}" distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment( init_distributed_environment(
parallel_config, world_size=pipeline_parallel_size * tensor_parallel_size,
rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
local_rank=local_rank) local_rank=local_rank)
ensure_model_parallel_initialized(tensor_parallel_size,
pipeline_parallel_size)
def multi_process_tensor_parallel( def multi_process_tensor_parallel(
......
from typing import Optional from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import * from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MPTConfig, RWConfig)
_CONFIG_REGISTRY = { _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
"mpt": MPTConfig, "mpt": MPTConfig,
......
...@@ -12,7 +12,7 @@ from transformers.utils import logging ...@@ -12,7 +12,7 @@ from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore
class DbrxAttentionConfig(PretrainedConfig): class DbrxAttentionConfig(PretrainedConfig):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment