Commit 7c4f76e3 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.0

parents 2da0dd3e 51c31bc1
from typing import List, Optional
import time
from typing import List, Optional, Union
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
SequenceStatus, RequestMetrics)
from vllm.lora.request import LoRARequest
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)
class CompletionOutput:
......@@ -18,6 +18,9 @@ class CompletionOutput:
logprobs: The log probabilities of the top probability words at each
position if the logprobs are requested.
finish_reason: The reason why the sequence is finished.
stop_reason: The stop string or token id that caused the completion
to stop, None if the completion finished for some other reason
including encountering the EOS token.
lora_request: The LoRA request that was used to generate the output.
"""
......@@ -29,6 +32,7 @@ class CompletionOutput:
cumulative_logprob: float,
logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None,
stop_reason: Union[int, str, None] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.index = index
......@@ -37,6 +41,7 @@ class CompletionOutput:
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason
self.stop_reason = stop_reason
self.lora_request = lora_request
def finished(self) -> bool:
......@@ -48,7 +53,8 @@ class CompletionOutput:
f"token_ids={self.token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"logprobs={self.logprobs}, "
f"finish_reason={self.finish_reason})")
f"finish_reason={self.finish_reason}, "
f"stop_reason={self.stop_reason})")
class RequestOutput:
......@@ -87,32 +93,33 @@ class RequestOutput:
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
if seq_group.sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)
if len(seqs) == 1:
top_n_seqs = seqs
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
# Get the top-n sequences.
n = seq_group.sampling_params.n
if seq_group.sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
# Create the outputs.
outputs: List[CompletionOutput] = []
for seq in top_n_seqs:
logprobs = seq.output_logprobs
if seq_group.sampling_params.logprobs is None:
# NOTE: We need to take care of this case because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
logprobs = None
finshed_reason = SequenceStatus.get_finished_reason(seq.status)
output = CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(),
seq.get_cumulative_logprob(), logprobs,
finshed_reason)
outputs.append(output)
# NOTE: We need to omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None
outputs = [
CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(),
seq.get_cumulative_logprob(),
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason) for seq in top_n_seqs
]
# Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt
......
from typing import Dict, List, Sequence, Tuple, Optional
from vllm.block import BlockTable
class Prefix:
    """Data and states associated with a prefix of prompt tokens for multiple
    sequence groups.

    NOTE: This feature is experimental and may be replaced with automatic
    prefix caching in the future.

    Args:
        token_ids: The token ids of the prefix. Any sequence is accepted; an
            immutable tuple copy is stored.
        block_size: The block size of the executed model. The prefix length
            must be a multiple of it.
    """

    def __init__(
        self,
        token_ids: Sequence[int],
        block_size: int,
    ) -> None:
        self.token_ids = tuple(token_ids)
        self.block_size = block_size
        self.length = len(self.token_ids)
        # Hash the stored tuple rather than the raw argument: the argument may
        # be an unhashable sequence such as a list, while the tuple copy is
        # always hashable and yields the same value for tuple inputs.
        self.hash = hash(self.token_ids)
        assert self.length % block_size == 0
        self.block_table: Optional["BlockTable"] = None
        self.computed = False

    @property
    def allocated(self) -> bool:
        """Whether a block table has been attached to this prefix."""
        return self.block_table is not None

    def get_num_blocks(self) -> int:
        """Return the number of blocks spanned by the prefix."""
        return self.length // self.block_size

    def get_block_numbers(self) -> List[int]:
        """Return the `block_number` of every block in the block table.

        Fails with a TypeError if no block table has been set yet, because
        `block_table` is then None.
        """
        return [block.block_number for block in self.block_table]

    def get_length(self) -> int:
        """Return the number of tokens in the prefix."""
        return self.length

    def __hash__(self) -> int:
        # The hash is precomputed from the token ids in __init__.
        return self.hash

    def set_block_table(self, block_table: "BlockTable") -> None:
        """Attach a copy of `block_table` to this prefix."""
        self.block_table = block_table.copy()
class PrefixPool:
    """Manages all the prompt prefixes.

    NOTE: This feature is experimental and may be replaced with automatic
    prefix caching in the future.

    Args:
        block_size: The block size of the executed model.

    Attributes:
        prefixes: Mapping from the hash of (prefix, lora id) to the prefix.
        block_size: The block size of the executed model.
    """

    def __init__(
        self,
        block_size: int,
    ) -> None:
        # TODO(zhuohan): Add a capacity limit to the prefix pool.
        self.prefixes: Dict[int, "Prefix"] = {}
        self.block_size = block_size

    def _truncate_token_ids(self,
                            token_ids: Sequence[int]) -> Tuple[int, ...]:
        """Drop trailing tokens so the length is a multiple of block_size."""
        new_length = len(token_ids) // self.block_size * self.block_size
        return tuple(token_ids[:new_length])

    def add_or_get_prefix(self, token_ids: Sequence[int],
                          lora_int_id: int) -> Optional["Prefix"]:
        """Return the pooled prefix for `token_ids`, creating it if needed.

        The token ids are first truncated to a whole number of blocks.
        Returns None when nothing is left after truncation.
        """
        token_ids = self._truncate_token_ids(token_ids)
        if not token_ids:
            # Prefix is empty.
            return None
        prefix = Prefix(token_ids, self.block_size)
        # NOTE: entries are keyed by hash value only, so two different
        # (prefix, lora id) pairs with equal hashes would share one entry.
        prefix_hash = hash((prefix, lora_int_id))
        # setdefault keeps an already-registered prefix and otherwise stores
        # the new one, in a single lookup.
        return self.prefixes.setdefault(prefix_hash, prefix)
......@@ -74,11 +74,13 @@ class SamplingParams:
stop_token_ids: List of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens.
include_stop_str_in_output: Whether to include the stop strings in output
text. Defaults to False.
include_stop_str_in_output: Whether to include the stop strings in
output text. Defaults to False.
ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
logprobs: Number of log probabilities to return per output token.
Note that the implementation follows the OpenAI API: The return
result includes the log probabilities on the `logprobs` most likely
......@@ -113,6 +115,7 @@ class SamplingParams:
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
min_tokens: int = 0,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
......@@ -144,6 +147,7 @@ class SamplingParams:
self.stop_token_ids = list(stop_token_ids)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.min_tokens = min_tokens
self.logprobs = logprobs
self.prompt_logprobs = prompt_logprobs
self.skip_special_tokens = skip_special_tokens
......@@ -161,6 +165,8 @@ class SamplingParams:
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
# injected by the engine
self.eos_token_id = None
def _verify_args(self) -> None:
if self.n < 1:
......@@ -191,6 +197,13 @@ class SamplingParams:
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.min_tokens < 0:
raise ValueError(f"min_tokens must be greater than or equal to 0, "
f"got {self.min_tokens}.")
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
if self.logprobs is not None and self.logprobs < 0:
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")
......@@ -272,6 +285,7 @@ class SamplingParams:
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"min_tokens={self.min_tokens}, "
f"logprobs={self.logprobs}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, "
......
......@@ -2,15 +2,34 @@
import copy
import enum
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock
from vllm.prefix import Prefix
from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]]
if TYPE_CHECKING:
import torch
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass
class Logprob:
    """Record backing OpenAI-compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1); None by default.
        decoded_token: The decoded chosen token (annotated as an optional
            string); None by default.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
SampleLogprobs = List[Dict[int, Logprob]]
class SequenceStatus(enum.Enum):
......@@ -54,7 +73,7 @@ class SequenceStatus(enum.Enum):
class RequestMetrics:
"""Metrics associated with a request.
Args:
Attributes:
arrival_time: The time when the request arrived.
first_scheduled_time: The time when the request was first scheduled.
first_token_time: The time when the first token was generated.
......@@ -74,6 +93,8 @@ class SequenceData:
Args:
prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output. Set to an empty list if
None.
Attributes:
prompt_token_ids: The token IDs of the prompt.
......@@ -84,10 +105,16 @@ class SequenceData:
def __init__(
self,
prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None,
) -> None:
if output_token_ids is None:
output_token_ids = []
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.output_token_ids = output_token_ids
self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0
def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
......@@ -105,11 +132,39 @@ class SequenceData:
def get_token_ids(self) -> List[int]:
    """Return a new list: the prompt token ids followed by the output
    token ids."""
    prompt_ids = self.prompt_token_ids
    return prompt_ids + self.output_token_ids
def get_num_computed_tokens(self) -> int:
    """Return how many prefill tokens have already been computed."""
    computed_so_far = self._num_computed_tokens
    return computed_so_far
def update_num_computed_tokens(self, num_new_computed_tokens: int) -> None:
    """Update number of tokens computed so far.

    Args:
        num_new_computed_tokens: Amount added to the running count.

    The counter is updated in place; nothing is returned.
    """
    self._num_computed_tokens += num_new_computed_tokens
def reset_num_computed_tokens(self) -> None:
    """Set the computed-token counter of this sequence back to zero.

    Intended for the case where a sequence has to be started from the
    beginning again (e.g., the sequence is preempted).
    """
    self._num_computed_tokens = 0
def get_num_uncomputed_tokens(self) -> int:
    """Return the number of prefill tokens that are not computed."""
    # The total comes from `get_len()` (prompt_len + output_len) rather than
    # prompt_len alone, because during recompute both the prompt and the
    # output have to be prefilled.
    total_len = self.get_len()
    return total_len - self.get_num_computed_tokens()
def get_last_token_id(self) -> int:
    """Return the most recent token id: the last output token if any output
    exists, otherwise the last prompt token."""
    generated = self.output_token_ids
    return generated[-1] if generated else self.prompt_token_ids[-1]
def get_prompt_token_ids(self) -> List[int]:
    """Return the prompt token ids (the stored list itself, not a copy)."""
    return self.prompt_token_ids
def get_output_token_ids(self) -> List[int]:
    """Return the output token ids (the stored list itself, not a copy)."""
    return self.output_token_ids
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self.prompt_token_ids}, "
......@@ -135,11 +190,13 @@ class Sequence:
prompt: str,
prompt_token_ids: List[int],
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.data = SequenceData(prompt_token_ids)
......@@ -150,6 +207,7 @@ class Sequence:
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(prompt_token_ids)
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None
# Used for incremental detokenization
self.prefix_offset = 0
......@@ -161,6 +219,23 @@ class Sequence:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def hash_of_block(self, logical_idx: int) -> int:
    """Hash the token ids from the start of the sequence through the end of
    logical block `logical_idx`, combined with the LoRA int id."""
    # TODO This can produce incorrect hash when block size > prompt size

    # Compute the number of tokens in the sequence
    # TODO: The current hashing function is O(L^2). We should optimize
    # this in the future.
    num_tokens = self.num_hashed_tokens_of_block(logical_idx)
    hashed_tokens = tuple(self.data.get_token_ids()[0:num_tokens])
    return hash((hashed_tokens, self.lora_int_id))
def num_hashed_tokens_of_block(self, logical_idx: int):
    """Return the token count from the sequence start through the end of
    logical block `logical_idx`."""
    block_size = self.block_size
    return (logical_idx + 1) * block_size
def reset_state_for_recompute(self):
    """Reset the sequence states for recomputation.

    This resets the computed-token counter kept on the sequence data.
    """
    sequence_data = self.data
    sequence_data.reset_num_computed_tokens()
def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks),
......@@ -187,12 +262,12 @@ class Sequence:
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, float],
logprobs: Dict[int, Logprob],
) -> None:
assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id])
self.data.append_token_id(token_id, logprobs[token_id].logprob)
def get_len(self) -> int:
    """Return the length reported by the underlying sequence data."""
    sequence_data = self.data
    return sequence_data.get_len()
......@@ -206,6 +281,9 @@ class Sequence:
def get_token_ids(self) -> List[int]:
    """Return the token ids reported by the underlying sequence data."""
    sequence_data = self.data
    return sequence_data.get_token_ids()
def get_prompt_token_ids(self) -> List[int]:
    """Return the prompt token ids reported by the underlying sequence
    data."""
    sequence_data = self.data
    return sequence_data.get_prompt_token_ids()
def get_last_token_id(self) -> int:
    """Return the last token id reported by the underlying sequence data."""
    sequence_data = self.data
    return sequence_data.get_last_token_id()
......@@ -256,6 +334,25 @@ class SequenceGroupState:
generator: Optional = None
class MultiModalData:
    """Multi modal request.

    Args:
        type: The kind of data carried, a member of `MultiModalData.Type`.
        data: The actual data.
            The required shape and semantic meaning of it depends on the
            vision language config of the hosted model.
            See `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        # Currently the only supported kind of multi modal data.
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        # Both arguments are stored unchanged.
        self.type, self.data = type, data
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
......@@ -265,7 +362,7 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
multi_modal_data: Multi modal data associated with the request.
"""
def __init__(
......@@ -275,7 +372,7 @@ class SequenceGroup:
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
......@@ -286,9 +383,9 @@ class SequenceGroup:
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.multi_modal_data = multi_modal_data
@property
def prompt(self) -> str:
......@@ -318,7 +415,8 @@ class SequenceGroup:
self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None:
"""Sets the first scheduled time and time in queue for Request level timings."""
"""Sets the first scheduled time and time in queue for Request
level timings."""
if self.metrics.first_scheduled_time is None:
self.metrics.first_scheduled_time = time
self.metrics.time_in_queue = time - self.metrics.arrival_time
......@@ -348,12 +446,9 @@ class SequenceGroup:
self,
status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
if status is None:
return list(self.seqs_dict.values())
else:
return [
seq for seq in self.seqs_dict.values() if seq.status == status
]
return list(self.seqs_dict.values()) if status is None else [
seq for seq in self.seqs_dict.values() if seq.status == status
]
def get_unfinished_seqs(self) -> List[Sequence]:
return [
......@@ -363,6 +458,18 @@ class SequenceGroup:
def get_finished_seqs(self) -> List[Sequence]:
    """Return the sequences of this group that report being finished, in
    the iteration order of `seqs_dict`."""
    all_seqs = self.seqs_dict.values()
    return [candidate for candidate in all_seqs if candidate.is_finished()]
def update_num_computed_tokens(self, num_new_computed_tokens: int):
    """Update number of tokens computed so far.

    The same increment is forwarded to the data of every sequence in the
    group.
    """
    for member in self.seqs_dict.values():
        member_data = member.data
        member_data.update_num_computed_tokens(num_new_computed_tokens)
def get_num_uncomputed_tokens(self) -> int:
    """Return the number of uncomputed tokens of the group's first
    sequence."""
    # All sequences in the group should have the same prompt, so the
    # number of unfinished prefill tokens are the same across all
    # sequences.
    members = list(self.seqs_dict.values())
    first_member = members[0]
    return first_member.data.get_num_uncomputed_tokens()
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
    """Return how many sequences `get_seqs(status)` yields for this group."""
    matching = self.get_seqs(status)
    return len(matching)
......@@ -397,7 +504,7 @@ class SequenceGroup:
class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`.
"""Metadata for a sequence group. Used to create `AttentionMetadata`.
Args:
request_id: The ID of the request.
......@@ -406,9 +513,11 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
token_chunk_size: The number of tokens to be processed. None if
chunking is not required.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
multi_modal_data: Multi modal data.
"""
def __init__(
......@@ -418,9 +527,11 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
......@@ -428,13 +539,26 @@ class SequenceGroupMetadata:
self.sampling_params = sampling_params
self.block_tables = block_tables
self.lora_request = lora_request
self.prefix = prefix
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self._token_chunk_size = token_chunk_size
if self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = list(seq_data.values())[0].get_len()
else:
self._token_chunk_size = 1
@property
def lora_int_id(self) -> int:
    """Integer id of the attached LoRA request, or 0 when there is none."""
    if not self.lora_request:
        return 0
    return self.lora_request.lora_int_id
@property
def token_chunk_size(self) -> int:
    """Return the number of tokens to be processed (chunk size)."""
    chunk_size = self._token_chunk_size
    return chunk_size
class SequenceOutput:
"""The model output associated with a sequence.
......@@ -451,7 +575,7 @@ class SequenceOutput:
self,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float],
logprobs: Dict[int, Logprob],
) -> None:
self.parent_seq_id = parent_seq_id
self.output_token = output_token
......@@ -465,9 +589,10 @@ class SequenceOutput:
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutput):
raise NotImplementedError()
return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)
equal = (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token)
log_probs_equal = other.logprobs == self.logprobs
return equal and log_probs_equal
class SequenceGroupOutput:
......@@ -492,6 +617,35 @@ class SequenceGroupOutput:
and self.prompt_logprobs == other.prompt_logprobs)
# For each sequence group, we generate a list of SequenceOutput object,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[SequenceGroupOutput]
@dataclass
class SamplerOutput:
    """For each sequence group, we generate a list of SequenceOutput object,
    each of which contains one possible candidate for the next token.

    This datastructure implements methods so it can be used like a list, but
    also has optional fields for device tensors.
    """

    outputs: List[SequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        # List-style read access goes straight to `outputs`.
        entries = self.outputs
        return entries[idx]

    def __setitem__(self, idx: int, value):
        # List-style write access goes straight to `outputs`.
        entries = self.outputs
        entries[idx] = value

    def __len__(self):
        entries = self.outputs
        return len(entries)

    def __eq__(self, other: object):
        # Only `outputs` takes part in equality; the optional tensor and
        # metrics fields are not compared.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs
from itertools import chain, count
from typing import Dict, Iterator, List, Optional, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
sampler_output_to_torch,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
SeqId = int
TargetSeqId = int
TokenId = int
class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

    def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            seq_group_metadata_list: The input sequence group metadata.
            blocks_to_swap_in: This is passed to the worker during scoring.
            blocks_to_swap_out: This is passed to the worker during scoring.
            blocks_to_copy: This is passed to the worker during scoring.
            k: The fixed proposal length.
            proposals: The speculative proposals to score.

        Returns:
            SpeculativeScores: The scores of each speculative token, along with
                which sequences were ignored during scoring.
        """

        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            seq_group_metadata_list=target_seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            return_python_output=False)

        all_tokens, all_probs = self._contract_batch(
            original_bs=len(seq_group_metadata_list),
            target_sampler_output=target_sampler_output,
            proposals=proposals,
            num_scoring_tokens=num_scoring_tokens,
            non_spec_indices=non_spec_indices,
            spec_indices=spec_indices,
            k=k,
        )

        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
        )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.

        Returns:
            A tuple of (indices of speculative sequences in the input batch,
            indices of non-speculative sequences in the input batch, the
            expanded batch with the scoring sequences first followed by the
            non-speculative sequences, number of scoring sequences).
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        spec_seqs, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)

        # `proposal_token_ids_list` has one row per sequence of the original
        # batch, whereas `spec_seqs` only holds the speculative subset. Select
        # the rows of the speculative sequences so that row i lines up with
        # spec_seqs[i] even when non-speculative sequences precede them.
        spec_proposal_token_ids = [
            proposal_token_ids_list[i] for i in spec_indices
        ]

        # Allocate target sequence ids above the ids of *all* input sequences
        # (not only the speculative ones) so that they cannot collide with
        # the ids of the non-speculative sequences appended below.
        target_seq_ids_iter = None
        if seq_group_metadata_list:
            target_seq_ids_iter = self._create_target_seq_id_iterator(
                get_all_seq_ids(seq_group_metadata_list))

        target_seq_group_metadata_list = self._create_scoring_model_input(
            spec_seqs, spec_proposal_token_ids, target_seq_ids_iter)
        num_scoring_tokens = len(target_seq_group_metadata_list)
        target_seq_group_metadata_list.extend(non_spec_seqs)

        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)

    def _contract_batch(self, original_bs: int,
                        target_sampler_output: SamplerOutput,
                        proposals: SpeculativeProposals,
                        num_scoring_tokens: int, non_spec_indices: List[int],
                        spec_indices: List[int],
                        k: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        Returns:
            A tuple (all_tokens, all_probs) of shapes [original_bs, k + 1]
            and [original_bs, k + 1, vocab_size]. Positions that received no
            score keep the fill values (-1 for tokens, 0 for probabilities).
        """
        (target_token_ids, target_probs, non_spec_target_token_ids,
         non_spec_target_probs) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # The proposal length is read from the proposal tensor, as before; the
        # `k` argument is kept for interface compatibility and is presumably
        # the same value.
        _, k = proposals.proposal_token_ids.shape

        # Map distinct sequences used to score each token
        # of shape [spec_bs * (k + 1)] back to [spec_bs, k + 1]. Only the
        # speculative sequences were expanded, so the leading dimension is
        # the number of speculative sequences, not the full batch size.
        spec_bs = len(spec_indices)

        target_token_ids = target_token_ids.squeeze().reshape(
            spec_bs, k + 1)
        target_probs = target_probs.squeeze().reshape(spec_bs, k + 1,
                                                      self._vocab_size)

        all_tokens = torch.full(size=(original_bs, k + 1),
                                fill_value=-1,
                                device=self._device,
                                dtype=torch.long)
        all_probs = torch.zeros(original_bs,
                                k + 1,
                                self._vocab_size,
                                device=self._device,
                                dtype=torch.float32)

        if non_spec_indices:
            # Non-speculative sequences only produce a single token.
            all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
            all_probs[non_spec_indices, :1, :] = non_spec_target_probs

        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs

        return all_tokens, all_probs

    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        target_seq_ids_iter: Optional[Iterator[TargetSeqId]] = None,
    ) -> List[SequenceGroupMetadata]:
        """Given the original input sequences and proposed tokens from the draft
        model, create a list of target sequences that can be used for scoring.

        Args:
            seq_group_metadata_list: The sequences to expand.
            proposal_token_ids: One row of proposal tokens per entry of
                `seq_group_metadata_list`, in the same order.
            target_seq_ids_iter: Source of fresh target sequence ids. When
                omitted, ids start above the largest sequence id found in
                `seq_group_metadata_list`.
        """

        if not seq_group_metadata_list:
            return []

        if target_seq_ids_iter is None:
            target_seq_ids_iter = self._create_target_seq_id_iterator(
                get_all_seq_ids(seq_group_metadata_list))

        return list(
            chain.from_iterable(
                self._create_target_seq_group_metadata(
                    seq_group_metadata,
                    proposal_token_ids,
                    i,
                    target_seq_ids_iter,
                ) for i, seq_group_metadata in enumerate(
                    seq_group_metadata_list)))

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given an input sequence group metadata and a list of draft tokens,
        create a list of target SequenceGroupMetadata, one for each
        token id that needs to be scored.

        Naive speculative decoding requires K target model scores, one for each
        draft model token. However one can add a bonus token such that if each
        token is accepted, then a final token may be sampled from the model.
        This function creates K+1 target SequenceGroupMetadata to take
        advantage of the bonus token.
        """
        assert not input_seq_group_metadata.is_prompt, (
            "Speculating on "
            "prompts not yet supported")
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search "
            "not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data))

        token_ids_to_score = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])

        # One target sequence, with its own fresh id, per continuation.
        return [
            self._create_single_target_seq_group_metadata(
                input_seq_group_metadata,
                input_seq_id,
                next(target_seq_ids_iter),
                token_ids,
            ) for token_ids in token_ids_to_score
        ]

    def _create_single_target_seq_group_metadata(
        self,
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = seq_data.get_prompt_token_ids()
        # A new list is built so the input sequence's outputs stay untouched.
        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data={
                target_seq_id:
                SequenceData(
                    prompt_token_ids=prompt_token_ids,
                    output_token_ids=new_output_token_ids,
                ),
            },
            sampling_params=seq_group_metadata.sampling_params,
            # The target sequence reuses the block table of the input one.
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
        )

    def _split_scoring_output(
        self, sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the target model output into speculative and non-speculative
        output.

        NOTE: `sampler_output` is modified in place; on return its
        `sampled_token_probs` / `sampled_token_ids` hold the non-speculative
        part only.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        #
        # First samples are from speculative scoring, latter samples are non-
        # speculative samples.
        split_sizes = [
            num_scoring_tokens,
            sampler_output.sampled_token_ids.numel() - num_scoring_tokens
        ]
        (spec_probs, non_spec_probs
         ) = sampler_output.sampled_token_probs.split(split_sizes)
        (spec_sampled_tokens, non_spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)

        # Convert scores to tensors.
        sampler_output.sampled_token_probs = spec_probs
        sampler_output.sampled_token_ids = spec_sampled_tokens
        target_token_ids, target_probs = sampler_output_to_torch(
            [sampler_output])

        # Convert non-speculative output tokens to tensors.
        sampler_output.sampled_token_probs = non_spec_probs
        sampler_output.sampled_token_ids = non_spec_sampled_tokens
        non_spec_target_token_ids, non_spec_target_probs = (
            sampler_output_to_torch([sampler_output]))

        return (target_token_ids, target_probs, non_spec_target_token_ids,
                non_spec_target_probs)

    def _create_target_seq_id_iterator(
            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
        """Create an iterator for creating target sequence ids.
        Target sequence ids are distinct from sequence ids because we create a
        distinct target sequence id for each proposal token to be scored.

        This implementation increments a counter starting at 1 + max of all
        provided input sequence ids.

        Raises:
            ValueError: If `seq_ids` is empty (raised by `max`).
        """
        return count(start=max(seq_ids) + 1)

    def _get_token_ids_to_score(
        self,
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given an int tensor of proposal token ids, return a list of
        token ids that should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        # Every prefix of the proposal, from the empty one to the full list.
        return [
            full_spec_token_ids[:end]
            for end in range(len(full_spec_token_ids) + 1)
        ]
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from vllm.sequence import SequenceGroupMetadata
@dataclass
class SpeculativeProposals:
    """Container for the proposal tokens produced by a proposer, together
    with the number of speculative tokens each sequence has.
    """

    # Speculative proposal tokens.
    proposal_token_ids: torch.Tensor

    # Probabilities of the proposal tokens according to the proposer.
    proposal_probs: torch.Tensor

    # The valid length of each proposal; can be zero.
    proposal_lens: torch.Tensor

    def __repr__(self):
        # Print tensor shapes only; the tensor contents are not included.
        token_ids_shape = self.proposal_token_ids.shape
        probs_shape = self.proposal_probs.shape
        lens_shape = self.proposal_lens.shape
        return ("SpeculativeProposals("
                f"proposal_token_ids={token_ids_shape}, "
                f"proposal_probs={probs_shape}, "
                f"proposal_lens={lens_shape})")
@dataclass
class SpeculativeScores:
    """Container for the scores the scoring model assigns to speculative
    tokens.
    """

    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

    def __repr__(self):
        # Print tensor shapes only; the tensor contents are not included.
        probs_shape = self.probs.shape
        token_ids_shape = self.token_ids.shape
        return ("SpeculativeScores("
                f"probs={probs_shape}, "
                f"token_ids={token_ids_shape})")
class SpeculativeProposer(ABC):
    """Abstract interface for anything that can produce speculative
    proposals for a batch of sequences.
    """

    @abstractmethod
    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce proposals for the given batch.

        Args:
            seq_group_metadata_list: The batch of sequence groups to
                speculate on.
            blocks_to_swap_in: Cache block mapping; presumably forwarded to
                the underlying worker (semantics are not defined here).
            blocks_to_swap_out: See ``blocks_to_swap_in``.
            blocks_to_copy: See ``blocks_to_swap_in``.
            max_proposal_len: Upper bound on the number of proposal tokens
                per sequence.

        Returns:
            The proposals as a ``SpeculativeProposals`` object.
        """
        # Concrete proposers must override this method.
        raise NotImplementedError
class SpeculativeScorer(ABC):
    """Abstract interface for anything that can score speculative proposals
    with the scoring (target) model.
    """

    @abstractmethod
    def score_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the given proposals.

        Args:
            seq_group_metadata_list: The batch of sequence groups the
                proposals belong to.
            blocks_to_swap_in: Optional cache block mapping; presumably
                forwarded to the underlying worker (semantics are not
                defined here).
            blocks_to_swap_out: See ``blocks_to_swap_in``.
            blocks_to_copy: See ``blocks_to_swap_in``.
            k: The batch proposal length.
            proposals: The proposals to score.

        Returns:
            A ``SpeculativeScores`` object. Callers read its ``probs`` and
            ``token_ids`` attributes, so the result is annotated as
            ``SpeculativeScores`` rather than as a bare tuple of tensors.
        """
        # Concrete scorers must override this method.
        raise NotImplementedError
import time
from dataclasses import dataclass
from typing import Callable, Optional
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.utils import is_pin_memory_available
@dataclass
class SpecDecodeWorkerMetrics:
    """Dataclass holding metrics emitted from the spec decode worker.

    All fields are required and have no defaults.
    """

    # The empirical acceptance rate of the proposal method on a per-token basis.
    # This is useful for evaluating how well the proposal method aligns with the
    # scoring method.
    draft_acceptance_rate: float

    # The empirical efficiency, measured as the number of tokens emitted by the
    # system divided by the number of tokens that could be emitted by the system
    # if the proposal method were perfect.
    system_efficiency: float

    # The number of speculative tokens produced by the proposal method.
    draft_tokens: int

    # The number of tokens emitted by the entire system.
    emitted_tokens: int

    # The number of tokens accepted by the scoring model and verification
    # routine, e.g. Llama2-70B and lossless rejection sampling.
    #
    # NOTE: Any token accepted by the verification routine is considered
    # accepted (regardless of whether the speculative prefix is also accepted).
    # The user will usually see fewer accepted tokens. This metric is helpful
    # when evaluating alignment of the proposal method with the scoring model.
    accepted_tokens: int

    # The number of speculative tokens per sequence.
    num_spec_tokens: int
# Zero-argument callable returning the current time as a float.
# AsyncMetricsCollector uses time.time when no Timer is supplied.
Timer = Callable[[], float]
class AsyncMetricsCollector:
    """Class which copies rejection sampler metrics from the device to CPU on a
    non-default Torch stream.

    Collection takes two calls to ``maybe_collect_rejsample_metrics``: the
    first call (once the collect interval has elapsed) starts an asynchronous
    copy, the next call waits for that copy and returns the metrics.
    """

    def __init__(self,
                 rejection_sampler: RejectionSampler,
                 timer: Optional[Timer] = None,
                 collect_interval_s: float = 5.0):
        """
        Args:
            rejection_sampler: Source of the ``num_accepted_tokens``,
                ``num_emitted_tokens`` and ``num_draft_tokens`` counters.
            timer: Callable returning the current time; defaults to
                ``time.time``. Can be replaced for testing.
            collect_interval_s: Minimum time between two collections, in the
                units returned by ``timer``.
        """
        self._rejection_sampler = rejection_sampler
        self._timer = time.time if timer is None else timer

        # Unknown until init_gpu_tensors() is called; metrics are only
        # collected on rank 0.
        self._rank: Optional[int] = None

        # We don't have a device set yet.
        self._copy_stream: Optional[torch.cuda.Stream] = None

        # Event of a copy that was started but not yet consumed.
        self._in_flight_copy: Optional[torch.cuda.Event] = None

        # CPU-side destinations of the asynchronous copies.
        pin_memory = is_pin_memory_available()
        self._aggregate_num_accepted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_emitted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_draft_tokens = 0

        self._rejsample_metrics_collect_interval_s = collect_interval_s
        self._last_metrics_collect_time = self._timer()

    def init_gpu_tensors(self, rank: int) -> None:
        """Record the worker rank and create the stream used for copies."""
        self._rank = rank
        self._copy_stream = torch.cuda.Stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
        """Return metrics if a previously started copy is pending, otherwise
        start a new copy when one is due and return None.

        Args:
            k: The number of speculative tokens per sequence.
        """
        # If a copy was initiated in the previous call, collect and return.
        if self._in_flight_copy is not None:
            ready_event = self._in_flight_copy
            self._in_flight_copy = None
            return self._collect_rejsample_metrics(k, ready_event)

        # Otherwise, check if we should start a new copy.
        if self._should_collect_rejsample_metrics(self._timer()):
            assert self._in_flight_copy is None
            self._in_flight_copy = self._copy_rejsample_metrics_async()

        return None

    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        """Return whether or not this iteration should collect rejection
        sampling metrics.

        Only rank 0 collects, and only once the collect interval has elapsed
        since the last collection.
        """
        if self._rank != 0:
            return False

        elapsed = now - self._last_metrics_collect_time
        return elapsed >= self._rejsample_metrics_collect_interval_s

    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
        """Copy rejection sampling metrics (number of accepted tokens, etc) to
        CPU asynchronously.

        Returns a CUDA event recording when the copy is complete.
        """
        # Make the copy stream wait for work already queued on the current
        # stream before reading the counters.
        self._copy_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self._rejection_sampler.num_accepted_tokens, non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self._rejection_sampler.num_emitted_tokens, non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self._rejection_sampler.num_draft_tokens)

        aggregate_metrics_ready = torch.cuda.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready

    def _collect_rejsample_metrics(
            self, k: int,
            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
        """Create metrics object from statistics copied asynchronously.

        Args:
            k: int. The number of speculative tokens; used to determine system
                efficiency.
            ready_event: torch.cuda.Event. The CUDA event recording when the
                async GPU->CPU copy is complete.
        """
        ready_event.synchronize()

        # Restart the interval. Without this, once the first interval has
        # elapsed every later call would be considered due and the
        # configured collect interval would no longer have any effect.
        self._last_metrics_collect_time = self._timer()

        accepted_tokens = self._aggregate_num_accepted_tokens.item()
        emitted_tokens = self._aggregate_num_emitted_tokens.item()
        draft_tokens = self._aggregate_num_draft_tokens

        num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)

        # Rates are undefined (NaN) while nothing has been speculated yet.
        if draft_tokens > 0:
            draft_acceptance_rate = accepted_tokens / draft_tokens
        else:
            draft_acceptance_rate = float("nan")

        if num_possible_tokens > 0:
            system_efficiency = emitted_tokens / num_possible_tokens
        else:
            system_efficiency = float("nan")

        return SpecDecodeWorkerMetrics(
            num_spec_tokens=k,
            draft_acceptance_rate=draft_acceptance_rate,
            system_efficiency=system_efficiency,
            accepted_tokens=accepted_tokens,
            draft_tokens=draft_tokens,
            emitted_tokens=emitted_tokens,
        )

    @staticmethod
    def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
        """Return how many tokens could have been emitted if every draft
        token had been accepted.

        This is the number of speculated sequences times ``k + 1``; the extra
        token per sequence is the bonus token. Returns 0 when ``k`` is not
        positive.
        """
        if k <= 0:
            return 0

        # Divide by k since batch size can be variable.
        total_num_spec_seqs = int(draft_tokens // k)

        # In the best case a sequence emits its k draft tokens plus one bonus
        # token.
        num_accepted_per_seq_if_all_accepted = k + 1

        return total_num_spec_seqs * num_accepted_per_seq_if_all_accepted
from typing import List, Dict
import copy
from typing import Dict, List, Optional, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker import Worker
......@@ -19,6 +22,21 @@ class MultiStepWorker(Worker):
requires more thought for MultiStepWorker support.
"""
def __init__(self, *args, **kwargs):
    """Construct the worker, forwarding every argument to the base class."""
    super().__init__(*args, **kwargs)

    # Proposer helper; remains None until init_device() creates it.
    self._proposer: Optional[DraftModelTop1Proposer] = None
def init_device(self):
    """Initialize the device via the base class, then build the proposer.

    The proposer is created afterwards and is given this worker's
    ``device``, ``max_model_len`` and ``vocab_size``.
    """
    super().init_device()

    self._proposer = DraftModelTop1Proposer(
        draft_worker=self,
        device=self.device,
        max_model_len=self.max_model_len,
        vocab_size=self.vocab_size,
    )
@torch.inference_mode()
def execute_model_multi_step(
self,
......@@ -58,6 +76,26 @@ class MultiStepWorker(Worker):
return model_outputs
def get_spec_proposals(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Dict[int, int],
    blocks_to_swap_out: Dict[int, int],
    blocks_to_copy: Dict[int, List[int]],
    max_proposal_len: int,
) -> SpeculativeProposals:
    """Produce speculations given an input batch of sequences.

    This is a thin wrapper: all arguments are handed unchanged to the
    proposer's ``get_proposals``. The number of speculative tokens per
    sequence is determined by ``max_proposal_len``.
    """
    proposer = self._proposer
    return proposer.get_proposals(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
        max_proposal_len=max_proposal_len,
    )
def _append_new_tokens(
self, model_output: SamplerOutput,
seq_group_metadata_list: SequenceGroupMetadata) -> None:
......@@ -77,7 +115,7 @@ class MultiStepWorker(Worker):
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob)
seq.append_token_id(token_id, token_logprob.logprob)
def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
......@@ -85,21 +123,9 @@ class MultiStepWorker(Worker):
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
The multi-step worker must be able to append tokens to sequences after
a forward pass. This necessitates modification of the data structures
used by the worker. Since these data structures are shared with other
parts of vLLM, like the scheduler, we must take care not to introduce
unexpected side-effects.
When Ray is used to orchestrate worker processes (such as when the
tensor-parallel degree is >1), this is not a problem because the input
datastructures will be serialized and created anew in the worker
process.
However, when Ray is not used to orchestrate the worker processes (such
as when the tensor-parallel degree is 1), this is a problem. We avoid
the problem by shallow-copying the input datastructures (specifically,
the parts that will change in multiple steps).
Helpful when the vLLM scheduler runs in the same process as the worker.
The alternative is deep-copying (or other form of deep copy); this has
performance downsides.
"""
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
......@@ -176,3 +202,169 @@ class MultiStepWorker(Worker):
for seq_group_metadata in seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
class DraftModelTop1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        draft_worker: MultiStepWorker,
        device: str,
        max_model_len: int,
        vocab_size: int,
    ):
        """
        Args:
            draft_worker: Worker whose ``execute_model_multi_step`` generates
                the draft tokens.
            device: Device on which the output tensors are created.
            max_model_len: Maximum sequence length of the draft model.
            vocab_size: Vocabulary size; used to shape empty probability
                tensors.
        """
        self._draft_worker = draft_worker
        self._device = device
        self._max_model_len = max_model_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        # Split speculative- and non-speculative- sequences.
        (proposal_lens, nonzero_proposal_len_seqs,
         nonzero_proposal_len_indices) = self._split_by_max_model_len(
             seq_group_metadata_list, max_proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                num_steps=max_proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            max_proposal_len=max_proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
        )

        return SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
        )

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        max_proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length.

        Returns:
            A tuple of
            * the proposal length of every sequence group (either
              ``max_proposal_len`` or 0),
            * the sequence groups that will be speculated on, and
            * the batch positions of those sequence groups.
        """
        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            # Only the first sequence of each group is inspected.
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal
            # len are supported.
            if seq_len + max_proposal_len < self._max_model_len:
                proposal_lens.append(max_proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (proposal_lens, nonzero_proposal_len_seqs,
                nonzero_proposal_len_indices)

    def _merge_outputs(
        self,
        batch_size: int,
        max_proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.

        Args:
            batch_size: Number of sequence groups in the full batch.
            max_proposal_len: The batch proposal length.
            maybe_sampler_output: Draft worker output for the speculated
                sequences (passed as-is to ``sampler_output_to_torch``), or
                None if nothing was speculated.
            proposal_lens: Per-sequence proposal lengths as Python ints.
            nonzero_proposal_len_indices: Batch positions of the speculated
                sequences.

        Returns:
            (proposal token ids, proposal probs, proposal lens) as tensors on
            ``self._device``.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case the token and probability tensors have zero rows
            # and every proposal length is zero.
            empty_proposal_tokens = torch.zeros(0,
                                                max_proposal_len,
                                                dtype=torch.long,
                                                device=self._device)
            empty_proposal_probs = torch.zeros(0,
                                               max_proposal_len,
                                               self._vocab_size,
                                               dtype=torch.float32,
                                               device=self._device)
            zero_proposal_lens = torch.zeros(len(proposal_lens),
                                             dtype=torch.long,
                                             device=self._device)
            return (empty_proposal_tokens, empty_proposal_probs,
                    zero_proposal_lens)

        proposal_tokens, proposal_probs = sampler_output_to_torch(
            maybe_sampler_output)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = torch.full(size=(batch_size,
                                                  *proposal_tokens.shape[1:]),
                                            fill_value=-1,
                                            dtype=torch.long,
                                            device=self._device)
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens

        # Skipped sequences keep all-zero probabilities.
        entire_proposal_probs = torch.zeros(batch_size,
                                            *proposal_probs.shape[1:],
                                            dtype=torch.float32,
                                            device=self._device)
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len

        return (entire_proposal_tokens, entire_proposal_probs,
                proposal_lens_tensor)
from functools import cached_property
from typing import Dict, List, Optional, Tuple
import torch
from vllm.config import CacheConfig
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
class SpecDecodeWorker:
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model proposal is implemented (contributions for more forms are
        welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * Only lossless rejection sampling is supported. Contributions adding lossy
        verification routines are welcome (e.g. Medusa's typical acceptance).
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add a MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

    def __init__(
        self,
        proposer_worker: MultiStepWorker,
        scorer_worker: Worker,
        rejection_sampler: RejectionSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            rejection_sampler: A Torch module used to perform modified rejection
                sampling for speculative decoding.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        self.rejection_sampler = rejection_sampler

        self._metrics = AsyncMetricsCollector(
            rejection_sampler
        ) if metrics_collector is None else metrics_collector

        self.probs_dtype = self.rejection_sampler.probs_dtype
        self.token_id_dtype = self.rejection_sampler.token_id_dtype

        # Created in init_device(), once the device is known.
        self.scorer: Optional[SpeculativeScorer] = None

    def init_device(self) -> None:
        """Initialize both scorer and proposer models.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        self._metrics.init_gpu_tensors(self.rank)
        self.rejection_sampler.init_gpu_tensors(self.rank)
        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

    def profile_num_available_blocks(self, block_size: int,
                                     gpu_memory_utilization: float,
                                     cpu_swap_space: int,
                                     cache_dtype: str) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model KV,
        such that the number of blocks is equal in both KV caches.

        Returns:
            (number of GPU blocks, number of CPU blocks). Only the GPU block
            count is reduced; the CPU block count is the scorer's value.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.profile_num_available_blocks(
                block_size, gpu_memory_utilization, cpu_swap_space,
                cache_dtype))

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes(
                block_size, cache_dtype))
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes(
                block_size, cache_dtype))

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig):
        """Initialize the cache engine of the scorer and proposer workers.
        """
        self.scorer_worker.init_cache_engine(cache_config)
        self.proposer_worker.init_cache_engine(cache_config)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        num_spec_tokens: int,
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Falls back to a plain (non-speculative) step when ``num_spec_tokens``
        is zero or the batch is empty.
        """

        assert seq_group_metadata_list is not None, (
            "speculative decoding "
            "requires non-None seq_group_metadata_list")

        # If no spec tokens, call the proposer and scorer workers normally.
        # Used for prefill.
        if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
            return self._run_no_spec(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

        return self._run_speculative_decoding_step(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            k=num_spec_tokens,
        )

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
    ) -> List[SamplerOutput]:
        """Run a prefill step, without any speculation. The input is sent to the
        proposer and scorer model so that the KV cache is consistent between the
        two.
        """

        self.proposer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            return_python_output=False)

        sampler_output = self.scorer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        # The tensors live in `sampled_token_probs` / `sampled_token_ids`;
        # assigning to any other attribute name would leave them in place.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        return [sampler_output]

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
    ) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """

        # Generate proposals using draft worker.
        proposals = self.proposer_worker.get_spec_proposals(
            seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
            blocks_to_copy, k)

        proposal_scores = self.scorer.score_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            k,
            proposals,
        )

        accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
                                                 proposal_scores, proposals, k)

        return self._create_output_sampler_list(seq_group_metadata_list,
                                                accepted_token_ids, k)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> torch.Tensor:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Returns the accepted token ids, one row per sequence in the order of
        ``seq_group_metadata_list``. Rows of non-speculative sequences hold
        the sampled token followed by -1 padding.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        # Scorer probabilities of the speculative rows, without the bonus
        # position.
        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
        # Bonus tokens of the speculative rows.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
        # Tokens sampled for the non-speculative rows.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Restrict the proposer outputs to the speculative rows as well, so
        # that every rejection sampler input covers the same sequences. This
        # matters when a batch mixes speculative and non-speculative
        # sequences.
        proposal_probs = proposals.proposal_probs[spec_indices]
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        accepted_token_ids = self.rejection_sampler(
            proposal_verifier_probs,
            bonus_token_ids,
            proposal_probs,
            proposal_token_ids,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])

        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        return accepted_token_ids

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        seq_ids = get_all_seq_ids(seq_group_metadata_list)

        # shape: [k+1, batch_size]
        accepted_token_ids_by_step = accepted_token_ids.transpose(0,
                                                                  1).tolist()
        sampler_output_list = []
        for token_ids_by_step in accepted_token_ids_by_step:
            # Stop at the first step in which no sequence has a token.
            if all(token_id == -1 for token_id in token_ids_by_step):
                break

            step_output_token_ids = []
            for token_id, seq_id in zip(token_ids_by_step, seq_ids):
                step_output_token_ids.append(
                    SequenceGroupOutput(
                        samples=[
                            SequenceOutput(
                                parent_seq_id=seq_id,
                                output_token=token_id,
                                # TODO Add verifier logprobs.
                                logprobs={token_id: 0.0},
                            )
                        ],
                        prompt_logprobs=None,
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Attach metrics, when available, to the first step's output.
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        return sampler_output_list

    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Rank of the scorer worker."""
        return self.scorer_worker.rank

    @property
    def device(self):
        """Device of the scorer worker."""
        return self.scorer_worker.device
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Compute the per-model block count when the scorer's KV cache budget is
    shared between the scorer (target) and proposer (draft) models.

    ``total_num_gpu_blocks`` is the number of GPU blocks that could be
    allocated to the target model alone. The block size in bytes usually
    differs between the two models, as it is a function of the number of
    KV/layer, the number of heads and the hidden dimension size.

    Both models allocate the same number of blocks, so the result is the
    block count for which the combined KV cache memory of both models is no
    larger than what the target model alone could have used.
    """
    scorer_budget_bytes = total_num_gpu_blocks * scorer_cache_block_size_bytes
    bytes_per_block_pair = (proposer_cache_block_size_bytes +
                            scorer_cache_block_size_bytes)
    return int(scorer_budget_bytes / bytes_per_block_pair)
from contextlib import contextmanager
from itertools import chain
from typing import List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
# Integer id of a sequence, i.e. a key of SequenceGroupMetadata.seq_data.
SeqId = int
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
    """Collect the sequence ids of every sequence group into one flat list.

    Ids appear in batch order; within a group they follow the key order of
    its ``seq_data`` mapping.
    """
    return [
        seq_id for seq_group_metadata in seq_group_metadata_list
        for seq_id in seq_group_metadata.seq_data.keys()
    ]
def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
    """Select the part of a batch whose proposal len is zero, or the part
    whose proposal len is non-zero.

    Args:
        seq_group_metadata_list: The batch of sequence groups.
        proposal_lens: Proposal length of each sequence group.
        select_proposal_len_zero: If true, keep the groups with a proposal
            len of zero; otherwise keep the groups with a non-zero one.

    Returns:
        The selected sequence groups and their positions in the batch.

    This helper should be removed once vLLM supports per-sequence proposal
    lens in a batch.
    """
    want_zero = bool(select_proposal_len_zero)

    selected_groups = []
    selected_indices = []
    paired = zip(seq_group_metadata_list, proposal_lens)
    for position, (seq_group, proposal_len) in enumerate(paired):
        has_zero_len = bool(proposal_len == 0)
        if has_zero_len == want_zero:
            selected_groups.append(seq_group)
            selected_indices.append(position)
    return selected_groups, selected_indices
def sampler_output_to_torch(
    sampler_output_list: List[SamplerOutput],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stack the tensors of several SamplerOutput steps, batch dimension
    first.

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]

        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]
    """
    per_step_probs = []
    per_step_token_ids = []
    for step_output in sampler_output_list:
        per_step_probs.append(step_output.sampled_token_probs)
        # Token ids are flattened to one id per sequence.
        per_step_token_ids.append(step_output.sampled_token_ids.flatten())

    # Steps are stacked along dim 0 and then swapped with the batch dim.
    # shape: [batch_size, num_sampler_output, vocab_size]
    sampled_token_probs = torch.stack(per_step_probs, dim=0).transpose(0, 1)

    # shape: [batch_size, num_sampler_output]
    sampled_token_ids = torch.stack(per_step_token_ids,
                                    dim=0).transpose(0, 1)

    return sampled_token_ids, sampled_token_probs
@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """
    Context manager / decorator which opens an NVTX range when its scope is
    entered and closes it when the scope is left, even if the scope raises.
    Any extra arguments are passed to msg.format() to build the range label.

    If running with cuda graphs, you must enable nsys cuda graph profiling.

    Arguments:
        msg (string): message to associate with the range
    """
    range_label = msg.format(*args, **kwargs)
    torch.cuda.nvtx.range_push(range_label)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()
......@@ -10,6 +10,7 @@ def init_test_distributed_environment(
tensor_parallel_size: int,
rank: int,
distributed_init_port: str,
local_rank: int = -1,
) -> None:
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
......@@ -18,8 +19,8 @@ def init_test_distributed_environment(
init_distributed_environment(
parallel_config,
rank,
cupy_port=None,
distributed_init_method=distributed_init_method)
distributed_init_method=distributed_init_method,
local_rank=local_rank)
def multi_process_tensor_parallel(
......
......@@ -6,10 +6,11 @@ from vllm.transformers_utils.configs import *
# Lookup table from a string key to one of the config classes provided by
# vllm.transformers_utils.configs.
# NOTE(review): keys look like HF `model_type` values — confirm against the
# lookup in get_config.
_CONFIG_REGISTRY = {
    "chatglm": ChatGLMConfig,
    "dbrx": DbrxConfig,
    "mpt": MPTConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
    "starcoder2": Starcoder2Config,
    "jais": JAISConfig,
}
......@@ -17,15 +18,6 @@ def get_config(model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None) -> PretrainedConfig:
# FIXME(woosuk): This is a temporary fix for StarCoder2.
# Remove this when the model is supported by HuggingFace transformers.
if "bigcode" in model and "starcoder2" in model:
config_class = _CONFIG_REGISTRY["starcoder2"]
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
return config
try:
config = AutoConfig.from_pretrained(
model,
......@@ -49,3 +41,17 @@ def get_config(model: str,
revision=revision,
code_revision=code_revision)
return config
def get_hf_text_config(config: PretrainedConfig):
    """Return the language-model part of a (possibly multi-modal) config.

    A config that carries a `text_config` attribute yields that nested
    config; any other config is returned unchanged.
    """
    if not hasattr(config, "text_config"):
        # Pure text model: nothing to unwrap.
        return config

    text_config = config.text_config
    # Downstream code relies on `num_attention_heads` (among others) being
    # present on the text config. Assert here to fail early if the
    # transformers config doesn't align with this assumption.
    assert hasattr(text_config, "num_attention_heads")
    return text_config
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
# Names exported by `from vllm.transformers_utils.configs import *`.
__all__ = [
    "ChatGLMConfig",
    "DbrxConfig",
    "MPTConfig",
    "RWConfig",
    "Starcoder2Config",
    "JAISConfig",
]
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# Copied from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
"""Dbrx configuration."""
from typing import Any, Optional
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DbrxAttentionConfig(PretrainedConfig):
    """Configuration of the attention layers of a Dbrx model.

    Holds the arguments used to instantiate [`DbrxAttention`] layers.
    Configuration objects inherit from [`PretrainedConfig`]; read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to `None`):
            If not `None`, clip the queries, keys, and values in the
            attention layer to this value.
        kv_n_heads (`int`, *optional*, defaults to 1):
            For grouped_query_attention only, allow user to specify number
            of kv heads.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base frequency for rope.

    Raises:
        ValueError: If a keyword argument other than `model_type` is given.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        # `model_type` is the only extra key that is tolerated; anything
        # still left in kwargs afterwards is rejected.
        kwargs.pop("model_type", None)
        if kwargs:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Build the attention config from a pretrained model name or path.

        When the loaded config dict has `model_type == "dbrx"`, its
        `attn_config` entry is used instead of the top-level dict.
        """
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A full Dbrx config nests the attention settings one level down.
        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        type_mismatch = (
            hasattr(cls, "model_type")
            and "model_type" in config_dict
            and config_dict["model_type"] != cls.model_type
        )
        if type_mismatch:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
class DbrxFFNConfig(PretrainedConfig):
    """Configuration of the feed-forward (FFN) layers of a Dbrx model.

    Holds the arguments used to instantiate [`DbrxFFN`] layers.
    Configuration objects inherit from [`PretrainedConfig`]; read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (dict, optional): A dict specifying activation function
            for the FFN. The dict should have a key 'name' with the value
            being the name of the activation function along with any
            additional keyword arguments. Defaults to `{"name": "silu"}`
            when `None` is passed.
        ffn_hidden_size (int, optional): The hidden size of the feedforward
            network. Defaults to 3584.
        moe_num_experts (int, optional): The number of experts in the
            mixture of experts layer. Defaults to 4.
        moe_top_k (int, optional): The number of experts to use in the
            mixture of experts layer. Defaults to 1.
        moe_jitter_eps (float, optional): The jitter epsilon for the
            mixture of experts layer. Defaults to `None`.
        moe_loss_weight (float, optional): The loss weight for the mixture
            of experts layer. Defaults to 0.01.
        moe_normalize_expert_weights (float, optional): The normalization
            factor for the expert weights. Defaults to 1.
        uniform_expert_assignment (bool, optional): Whether to use uniform
            expert assignment. This should only be used for benchmarking
            purposes. Defaults to `False`.

    Raises:
        ValueError: If a keyword argument other than `model_type` is given.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        # Extra keyword arguments are not forwarded to the base class.
        super().__init__()

        # Build a fresh default dict per instance.
        self.ffn_act_fn = {"name": "silu"} if ffn_act_fn is None else ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        # `model_type` is the only extra key that is tolerated; anything
        # still left in kwargs afterwards is rejected.
        kwargs.pop("model_type", None)
        if kwargs:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Build the FFN config from a pretrained model name or path.

        When the loaded config dict has `model_type == "dbrx"`, its
        `ffn_config` entry is used instead of the top-level dict.
        """
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A full Dbrx config nests the FFN settings one level down.
        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        type_mismatch = (
            hasattr(cls, "model_type")
            and "model_type" in config_dict
            and config_dict["model_type"] != cls.model_type
        )
        if type_mismatch:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
class DbrxConfig(PretrainedConfig):
    """Top-level configuration of a Dbrx model.

    Stores the arguments used to instantiate a [`DbrxModel`], including
    the nested attention and FFN sub-configurations. Configuration objects
    inherit from [`PretrainedConfig`]; read the documentation from
    [`PretrainedConfig`] for more information.

    Args:
        d_model (`int`, *optional*, defaults to 2048):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the
            Transformer encoder.
        n_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 2048):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Dbrx model. Defines the maximum number
            of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before
            combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`DbrxAttentionConfig` or `dict`, *optional*):
            Configuration of the model's attention module. A `dict` is
            expanded into a [`DbrxAttentionConfig`]; `None` creates one
            with default values.
        ffn_config (`DbrxFFNConfig` or `dict`, *optional*):
            Configuration of the model's FFN module. A `dict` is expanded
            into a [`DbrxFFNConfig`]; `None` creates one with default
            values.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the
            model. Enabling this will also allow the model to output the
            auxiliary loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.05):
            The aux loss factor for the total loss.

    Raises:
        ValueError: If `tie_word_embeddings` is passed with a truthy value.

    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel
    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig()
    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    # Generic attribute names mapped onto the Dbrx-specific field names.
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        # Normalise the sub-configs: build defaults for None, expand dicts,
        # and store any other object as given.
        if attn_config is None:
            attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            attn_config = DbrxAttentionConfig(**attn_config)
        self.attn_config = attn_config

        if ffn_config is None:
            ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            ffn_config = DbrxFFNConfig(**ffn_config)
        self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        # Tied input/output embeddings are rejected explicitly.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Dbrx models."
            )

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
# coding=utf-8
# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAIS configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class JAISConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a
    [`JAISModel`]. It is used to instantiate a JAIS model according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used
    to control the model outputs. Read the documentation from
    [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the JAIS model. Defines the number of different
            tokens that can be represented by the
            `inputs_ids` passed when calling [`JAISModel`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used
            with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the
            Transformer encoder.
        n_inner (`int`, *optional*, defaults to None):
            Dimensionality of the inner feed-forward layers. `None` will set
            it to 4 times n_embd
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list
            `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in
            the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 50256):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 50256):
            The id of the "end-of-sequence" token.
        scale_attn_by_inverse_layer_idx (`bool`, *optional*,
            defaults to `False`):
            Whether to additionally scale attention weights by
            `1 / layer_idx + 1`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention
            (dot-product)
            and upcast attention dot-product/softmax to float() when training
            with mixed precision.
        position_embedding_type (`str`, *optional*, defaults to `"learned"`):
            Positional embedding can be either `"alibi"` or `"learned"`.
        mup_width_scale (`float`, *optional*, defaults to 1.0):
            muP parameter to scale learning rate and initializers. Calculated
            as (`d_model,0 / d_model`), where
            `d_model` is the model's width and `d_model,0` is the proxy
            model's width.
        mup_embeddings_scale (`float`, *optional*, defaults to 1.0):
            muP parameter to scale token and position embeddings.
        mup_output_alpha (`float`, *optional*, defaults to 1.0):
            muP parameter to scale output logits
            (`output_logits_scale = mup_output_alpha * mup_width_scale`).
        mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`):
            Scale attention weights by dividing by hidden_size instead of
            sqrt(hidden_size). Need to set scale_attn_weights to `True` as
            well.
        alibi_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for ALiBi
            embeddings. Currently only supports linear
            scaling strategy. Can specify either the scaling `factor` (must be
            a float greater than 1) for fixed scaling
            or `train_seq_len` for dynamic scaling on input samples with
            sequence length > `train_seq_len`. The expected
            formats are `{"type": strategy name, "factor": scaling factor}` or
            `{"type": strategy name,
            "train_seq_len": training sequence length}`.
        architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']):
            architecture names for Jais.

    Raises:
        ValueError: If `alibi_scaling` is given but is malformed.

    Example:

    ```python
    >>> from transformers import JAISConfig, JAISModel

    >>> # Initializing a JAIS configuration
    >>> configuration = JAISConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = JAISModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "jais"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Generic attribute names mapped onto the JAIS-specific field names.
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        position_embedding_type="learned",
        mup_width_scale=1.0,
        mup_embeddings_scale=1.0,
        mup_output_alpha=1.0,
        mup_scale_qk_dot_by_d=False,
        alibi_scaling=None,
        architectures=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.position_embedding_type = position_embedding_type
        self.mup_width_scale = mup_width_scale
        self.mup_embeddings_scale = mup_embeddings_scale
        self.mup_output_alpha = mup_output_alpha
        self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d

        self.alibi_scaling = alibi_scaling
        # Reject a malformed `alibi_scaling` before finishing construction.
        self._alibi_scaling_validation()

        if architectures is None:
            architectures = ["JAISLMHeadModel"]

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            architectures=architectures,
            **kwargs,
        )

    def _alibi_scaling_validation(self):
        """
        Validate the `alibi_scaling` configuration.

        `None` is accepted as "no scaling". Otherwise the value must be a
        two-entry dict holding `type` (only `"linear"` is supported) plus
        either `factor` (a float > 1.0) or `train_seq_len` (an int > 1).

        Raises:
            ValueError: If `alibi_scaling` does not match that format.
        """
        if self.alibi_scaling is None:
            return

        if (not isinstance(self.alibi_scaling, dict)
                or len(self.alibi_scaling) != 2):
            raise ValueError(
                "`alibi_scaling` must be a dictionary with two fields, "
                "`type` and `factor` or `type` and `train_seq_len`, "
                f"got {self.alibi_scaling}")

        alibi_scaling_type = self.alibi_scaling.get("type", None)
        alibi_scaling_factor = self.alibi_scaling.get("factor", None)
        alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None)

        if alibi_scaling_type != "linear":
            raise ValueError("`alibi_scaling`'s type field must be 'linear', "
                             f"got {alibi_scaling_type}")

        # With exactly two entries and `type` present, at most one of the
        # two scaling keys can be set; require that one actually is.
        if alibi_scaling_factor is None and alibi_dynamic_scaling is None:
            raise ValueError(
                "`alibi_scaling` must be a dictionary with two fields, "
                "`type` and `factor` or `type` and `train_seq_len`, "
                f"got {self.alibi_scaling}")

        # The parentheses matter: each range check must only run when the
        # field is present, otherwise `None <= 1.0` raises TypeError for
        # every valid configuration (only one of the two fields is set).
        if alibi_scaling_factor is not None and (
                not isinstance(alibi_scaling_factor, float)
                or alibi_scaling_factor <= 1.0):
            raise ValueError(
                "`alibi_scaling`'s factor field must be a float > 1.0, "
                f"got {alibi_scaling_factor}")

        if alibi_dynamic_scaling is not None and (
                not isinstance(alibi_dynamic_scaling, int)
                or alibi_dynamic_scaling <= 1):
            raise ValueError(
                "`alibi_scaling`'s `train_seq_len` field must be an "
                f"integer > 1, got {alibi_dynamic_scaling}")
......@@ -4,6 +4,7 @@
"""A HuggingFace-style model configuration."""
import warnings
from typing import Any, Dict, Optional, Union
from transformers import PretrainedConfig
attn_config_defaults: Dict = {
......@@ -62,62 +63,6 @@ class MPTConfig(PretrainedConfig):
fc_type: str = 'torch',
verbose: Optional[int] = None,
**kwargs: Any):
"""The MPT configuration class.
Args:
d_model (int): The size of the embedding dimension of the model.
n_heads (int): The number of attention heads.
n_layers (int): The number of layers in the model.
expansion_ratio (int): The ratio of the up/down scale in the ffn.
max_seq_len (int): The maximum sequence length of the model.
vocab_size (int): The size of the vocabulary.
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
emb_pdrop (float): The dropout probability for the embedding layer.
learned_pos_emb (bool): Whether to use learned positional embeddings
attn_config (Dict): A dictionary used to configure the model's attention module:
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
attn_pdrop (float): The dropout probability for the attention layers.
attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
use the default scale of ``1/sqrt(d_keys)``.
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
which sub-sequence each token belongs to.
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
alibi (bool): Whether to use the alibi bias instead of position embeddings.
alibi_bias_max (int): The maximum value of the alibi bias.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
verbose (int): The verbosity level. 0 is silent.
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
norm_type (str): choose type of norm to use
use_cache (bool): Whether or not the model should return the last key/values attentions
init_config (Dict): A dictionary used to configure the model initialization:
init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
init_std (float): The standard deviation of the normal distribution used to initialize the model,
if using the baseline_ parameter initialization scheme.
init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
---
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
"""
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
......@@ -139,8 +84,8 @@ class MPTConfig(PretrainedConfig):
self.fc_type = fc_type
if verbose is not None:
warnings.warn(DeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
),
'verbose argument for MPTConfig is now ignored and '
'will be removed. Use python_log_level instead.'),
stacklevel=2)
if 'name' in kwargs:
del kwargs['name']
......@@ -149,7 +94,8 @@ class MPTConfig(PretrainedConfig):
if self.attn_config.get('alibi', False):
self.learned_pos_emb = False
warnings.warn(
f'alibi is turned on, setting `learned_pos_emb` to {self.learned_pos_emb}`',
f'alibi is turned on, setting `learned_pos_emb` '
f'to {self.learned_pos_emb}`',
stacklevel=2)
super().__init__(**kwargs)
self._validate_config()
......@@ -176,8 +122,8 @@ class MPTConfig(PretrainedConfig):
[self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop]
)):
raise ValueError(
"self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" # pylint: disable=line-too-long
)
"self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are "
"probabilities and must be between 0 and 1")
if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
raise ValueError(
f"Unknown attn_impl={self.attn_config['attn_impl']}")
......@@ -193,17 +139,17 @@ class MPTConfig(PretrainedConfig):
if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
'attn_impl'] not in ['torch', 'triton']:
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.' # pylint: disable=line-too-long
)
'attn_uses_sequence_id only implemented with torch '
'and triton attention.')
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' # pylint: disable=line-too-long
)
'model.embedding_fraction must be between 0 (exclusive) '
'and 1 (inclusive)!')
if isinstance(self.logit_scale,
str) and self.logit_scale != 'inv_sqrt_d_model':
raise ValueError(
f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." # pylint: disable=line-too-long
)
f"self.logit_scale={self.logit_scale!r} is not recognized as "
"an option; use numeric value or 'inv_sqrt_d_model'.")
if self.init_config.get('name', None) is None:
raise ValueError(
f"self.init_config={self.init_config!r} 'name' needs to be set."
......@@ -219,11 +165,11 @@ class MPTConfig(PretrainedConfig):
del te
except Exception as exc:
raise ImportError(
# pylint: disable=line-too-long
'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. '
+
'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n'
+ 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
'TransformerEngine import fail. `fc_type: te` requires '
'TransformerEngine be installed. '
'The required version of transformer_engine also requires '
'FlashAttention v1.0.6 is installed:\n'
'pip install flash-attn==1.0.6 --no-build-isolation \n'
'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
) from exc
if self.ffn_config['ffn_type'] == 'mptmlp':
......
from transformers import PretrainedConfig
class Starcoder2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Starcoder2Model`]. It is used to instantiate a
Starcoder2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the [bigcode/starcoder2-7b_16k](https://huggingface.co/bigcode/starcoder2-7b_16k) model.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 49152):
Vocabulary size of the Starcoder2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Starcoder2Model`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 12288):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 30):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 24):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with. Starcoder2's sliding window attention
allows sequence of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_epsilon (`float`, *optional*, defaults to 1e-05):
Epsilon value for the layer norm
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
bos_token_id (`int`, *optional*, defaults to 50256):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 50256):
The id of the "end-of-sequence" token.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `None` (no sliding window).
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
residual_dropout (`float`, *optional*, defaults to 0.0):
Residual connection dropout value.
embedding_dropout (`float`, *optional*, defaults to 0.0):
Embedding dropout.
use_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias term on linear layers of the model.
```python
>>> from transformers import Starcoder2Model, Starcoder2Config
>>> # Initializing a Starcoder2 7B style configuration
>>> configuration = Starcoder2Config()
>>> # Initializing a model from the Starcoder2 7B style configuration
>>> model = Starcoder2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "starcoder2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
    self,
    vocab_size=49152,
    hidden_size=3072,
    intermediate_size=12288,
    num_hidden_layers=30,
    num_attention_heads=24,
    num_key_value_heads=2,
    hidden_act="gelu_pytorch_tanh",
    max_position_embeddings=4096,
    initializer_range=0.018042,
    norm_epsilon=1e-5,
    use_cache=True,
    bos_token_id=50256,
    eos_token_id=50256,
    rope_theta=10000.0,
    sliding_window=None,
    attention_dropout=0.0,
    residual_dropout=0.0,
    embedding_dropout=0.0,
    use_bias=True,
    **kwargs,
):
    """Record the Starcoder2 hyper-parameters on the configuration object.

    Every named argument except ``bos_token_id`` and ``eos_token_id`` is
    stored unchanged as an attribute of the same name. The two token ids and
    any extra ``kwargs`` are forwarded to the parent initializer. If
    ``architectures`` is still ``None`` afterwards, it is set to
    ``['Starcoder2ForCausalLM']``.
    """
    # Vocabulary and layer geometry.
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers

    # Attention layout.
    self.num_attention_heads = num_attention_heads
    self.num_key_value_heads = num_key_value_heads
    self.sliding_window = sliding_window
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings

    # Activation, bias, normalisation and weight initialisation.
    self.hidden_act = hidden_act
    self.use_bias = use_bias
    self.norm_epsilon = norm_epsilon
    self.initializer_range = initializer_range

    # Dropout ratios.
    self.attention_dropout = attention_dropout
    self.residual_dropout = residual_dropout
    self.embedding_dropout = embedding_dropout

    self.use_cache = use_cache

    # The parent initializer receives the special token ids and all
    # remaining keyword arguments.
    super().__init__(
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        **kwargs,
    )

    # Fall back to the causal-LM architecture name when none was supplied.
    if self.architectures is None:
        self.architectures = ['Starcoder2ForCausalLM']
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.tokenizer import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
# Used e.g. for marking rejected tokens in speculative decoding.
INVALID_TOKEN_ID = -1
class Detokenizer:
    """Provides methods to decode the output of a model into text."""

    def __init__(self, tokenizer_group: BaseTokenizerGroup):
        """Store the tokenizer group used to look up per-sequence tokenizers.

        Args:
            tokenizer_group: Object whose ``get_lora_tokenizer`` is called
                with a sequence's ``lora_request``.
        """
        self.tokenizer_group = tokenizer_group

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        """Returns the HF tokenizer to use for a given sequence."""
        return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

    def decode_prompt_logprobs_inplace(
            self, seq_group: SequenceGroup,
            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
        """Decodes the logprobs for the prompt of a sequence group.

        In-place operation: the decoded text is written to the
        ``decoded_token`` attribute of each logprob entry that does not
        have one yet. Nothing is returned.

        Args:
            seq_group: The sequence group to decode.
            prompt_logprobs: The logprobs to decode, one (possibly empty or
                ``None``) mapping of token id to logprob per prompt position.
        """
        prms = seq_group.sampling_params
        # We can pick any sequence for the prompt.
        seq = next(iter(seq_group.seqs_dict.values()))
        # Only prompt, without the generated token.
        all_token_ids = seq.get_token_ids()
        prompt_token_ids = all_token_ids[:-1]
        tokenizer = self.get_tokenizer_for_seq(seq)
        prefix_offset = 0
        read_offset = 0
        next_iter_prefix_offset = 0
        next_iter_read_offset = 0
        next_iter_tokens = []
        prev_tokens = None

        for token_position, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):
            # Positions without logprobs (None or empty) are skipped and do
            # not advance the detokenization state.
            if not prompt_logprobs_for_token:
                continue
            for token_id, sample_logprob in prompt_logprobs_for_token.items():
                if (sample_logprob.decoded_token is None
                        and token_id != INVALID_TOKEN_ID):
                    # Decode the candidate as if it followed the real prompt
                    # prefix up to this position.
                    prompt_token_ids_with_token = (
                        prompt_token_ids[:token_position] + [token_id])
                    (new_tokens, new_text, new_prefix_offset,
                     new_read_offset) = detokenize_incrementally(
                         tokenizer=tokenizer,
                         all_input_ids=prompt_token_ids_with_token,
                         prev_tokens=prev_tokens,
                         prefix_offset=prefix_offset,
                         read_offset=read_offset,
                         skip_special_tokens=prms.skip_special_tokens,
                         spaces_between_special_tokens=prms.
                         spaces_between_special_tokens,
                     )

                    sample_logprob.decoded_token = new_text

                    # Use the offsets & prev tokens corresponding to the
                    # real token so that detokenization stays consistent
                    # with the actual prompt.
                    if token_id == all_token_ids[token_position]:
                        next_iter_prefix_offset = new_prefix_offset
                        next_iter_read_offset = new_read_offset
                        next_iter_tokens = new_tokens

            # Advance to the next token position.
            prefix_offset = next_iter_prefix_offset
            read_offset = next_iter_read_offset
            if prev_tokens is None:
                # Copy instead of aliasing: prev_tokens is extended in place
                # on later positions, and sharing the list object with
                # next_iter_tokens would let that extend feed the list into
                # itself whenever next_iter_tokens is not rebound.
                prev_tokens = next_iter_tokens.copy()
            else:
                prev_tokens.extend(next_iter_tokens)

    def decode_sequence_inplace(self, seq: Sequence,
                                prms: SamplingParams) -> None:
        """Decodes the new token for a sequence. In-place operation.

        Updates ``seq.tokens``, ``seq.prefix_offset``, ``seq.read_offset``
        and ``seq.output_text``, and fills ``decoded_token`` on the entries
        of the latest ``seq.output_logprobs`` element.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.
        """
        all_input_ids = seq.get_token_ids()
        token_id_generated_this_iteration = all_input_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_tokens = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue

                if (sample_logprob.decoded_token is None
                        and token_id != INVALID_TOKEN_ID):
                    # Decode the alternative token against the same state
                    # (seq.tokens / offsets are not updated until below).
                    all_input_ids_with_logprob = previous_tokens + [token_id]
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                    )
                    sample_logprob.decoded_token = new_text

        # Commit the state for the next iteration. seq.tokens is normally
        # already populated by the conversion above; the None branch is kept
        # as a defensive guard.
        if seq.tokens is None:
            seq.tokens = new_tokens
        else:
            seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text
......@@ -5,12 +5,52 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import *
from vllm.utils import make_async
logger = init_logger(__name__)
def get_cached_tokenizer(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Return ``tokenizer`` with a handful of its properties cached.

    The tokenizer object is patched in place: its class is replaced by a
    dynamically created subclass of its original class.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. The
    subclass serves ``all_special_ids``, ``all_special_tokens``,
    ``all_special_tokens_extended`` and ``len()`` from values captured once,
    at the time of this call.
    """
    original_cls = tokenizer.__class__

    # Snapshot every value exactly once; the subclass below closes over them.
    cached_special_ids = set(tokenizer.all_special_ids)
    cached_special_tokens_extended = tokenizer.all_special_tokens_extended
    cached_special_tokens = set(tokenizer.all_special_tokens)
    cached_length = len(tokenizer)

    class CachedTokenizer(original_cls):

        @property
        def all_special_ids(self):
            return cached_special_ids

        @property
        def all_special_tokens(self):
            return cached_special_tokens

        @property
        def all_special_tokens_extended(self):
            return cached_special_tokens_extended

        def __len__(self):
            return cached_length

    # Keep the original class name recognisable in reprs and logs.
    CachedTokenizer.__name__ = f"Cached{original_cls.__name__}"

    # Swap the instance's class so existing references see the cached view.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer
def get_tokenizer(
tokenizer_name: str,
*args,
......@@ -64,7 +104,7 @@ def get_tokenizer(
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead.")
return tokenizer
return get_cached_tokenizer(tokenizer)
def get_lora_tokenizer(lora_request: LoRARequest, *args,
......@@ -88,63 +128,6 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
get_lora_tokenizer_async = make_async(get_lora_tokenizer)
class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters."""

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        """Load the base tokenizer and, if LoRA is enabled, an adapter cache.

        Args:
            tokenizer_id: Identifier passed to ``get_tokenizer``.
            enable_lora: Whether per-adapter tokenizers are looked up.
            max_num_seqs: Capacity of the LRU cache of adapter tokenizers.
            max_input_length: Stored on the instance; not used by the
                methods of this class.
            **tokenizer_config: Extra keyword arguments forwarded to the
                tokenizer loaders.
        """
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
        # The adapter cache only exists when LoRA support is switched on.
        self.lora_tokenizers = (LRUCache(
            capacity=max_num_seqs) if enable_lora else None)

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode ``prompt`` with the tokenizer selected by ``lora_request``."""
        return self.get_lora_tokenizer(lora_request).encode(prompt)

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Async variant of ``encode``; the tokenizer lookup is awaited."""
        selected = await self.get_lora_tokenizer_async(lora_request)
        return selected.encode(prompt)

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``.

        The base tokenizer is returned when there is no request or LoRA is
        disabled. Otherwise the adapter tokenizer is taken from the cache,
        loading and caching it on a miss (falling back to the base tokenizer
        if the loader returns a falsy value).
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id in self.lora_tokenizers:
            return self.lora_tokenizers.get(lora_request.lora_int_id)
        loaded = get_lora_tokenizer(lora_request, **self.tokenizer_config)
        tokenizer = loaded or self.tokenizer
        self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
        return tokenizer

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Async variant of ``get_lora_tokenizer``; a cache miss is awaited."""
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id in self.lora_tokenizers:
            return self.lora_tokenizers.get(lora_request.lora_int_id)
        loaded = await get_lora_tokenizer_async(lora_request,
                                                **self.tokenizer_config)
        tokenizer = loaded or self.tokenizer
        self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
        return tokenizer
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str],
......@@ -179,6 +162,34 @@ def _convert_tokens_to_string_with_added_encoders(
return "".join(sub_texts)
# Size of the token window kept before the read position when incremental
# detokenization starts. 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Convert the tail of the prompt ids to tokens for incremental decoding.

    Only the trailing ids needed to seed incremental detokenization are
    converted; the rest of the prompt is never turned into strings.

    Returns:
        A ``(tokens, prefix_offset, read_offset)`` tuple, where
        ``read_offset`` is the number of converted tokens and
        ``prefix_offset`` lies at most
        ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET`` tokens before it.
    """
    # Start two ids earlier than strictly needed, in case special tokens
    # get dropped by the conversion.
    tail_start = max(
        len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
    tail_ids = prompt_ids[tail_start:]

    # We do not need to convert the whole prompt to tokens.
    new_tokens = tokenizer.convert_ids_to_tokens(
        tail_ids, skip_special_tokens=skip_special_tokens)

    # Offsets are relative to the converted tail, not the full prompt.
    read_offset = len(new_tokens)
    prefix_offset = max(
        read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
    return new_tokens, prefix_offset, read_offset
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
......@@ -186,31 +197,57 @@ def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int = 0,
read_offset: int = 0,
prefix_offset: int,
read_offset: int,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
If `prev_tokens` is None, this function will convert the input ids to
tokens and return the tokens and the new text. Otherwise, it will return the
new tokens and the new text.
This function will also return the new prefix offset and the new read
offset to be used in the next iteration.
The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.
Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
if prev_tokens is None:
new_tokens = tokenizer.convert_ids_to_tokens(
all_input_ids, skip_special_tokens=skip_special_tokens)
output_tokens = new_tokens
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
# Subtract 1 extra to account for the generated token.
prefix_offset = max(len(output_tokens) - 6, 0)
# If the first new token is a special token, we can't skip 1 extra token
if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
read_offset = max(len(output_tokens), 0)
else:
read_offset = max(len(output_tokens) - 1, 0)
is_first_iter = prev_tokens is None
if is_first_iter:
(prev_tokens, prefix_offset,
read_offset) = convert_prompt_ids_to_tokens(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
# If the new token id is out of bounds, return an empty string.
if new_token_id >= len(tokenizer):
new_tokens = [""]
else:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
output_tokens = prev_tokens + new_tokens
output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens.
if is_first_iter:
new_tokens = output_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
......
from typing import Optional
from vllm.config import TokenizerPoolConfig
from vllm.engine.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
# Import the Ray-backed pool only when the `ray` object exported by
# vllm.engine.ray_utils is truthy. Otherwise bind the name to None so that
# get_tokenizer_group can detect the missing backend and raise ImportError.
if ray:
    from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
        RayTokenizerGroupPool)
else:
    RayTokenizerGroupPool = None
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
                        **init_kwargs) -> BaseTokenizerGroup:
    """Build the tokenizer group selected by ``tokenizer_pool_config``.

    Args:
        tokenizer_pool_config: ``None`` for a plain ``TokenizerGroup``;
            otherwise its ``pool_type`` picks the pooled implementation.
        **init_kwargs: Forwarded to the chosen group's constructor/factory.

    Raises:
        ImportError: ``pool_type`` is ``"ray"`` but the Ray pool class is
            not available.
        ValueError: ``pool_type`` is not a recognised value.
    """
    # No pool requested: use a plain tokenizer group.
    if tokenizer_pool_config is None:
        return TokenizerGroup(**init_kwargs)

    # "ray" is the only pool type handled here.
    if tokenizer_pool_config.pool_type != "ray":
        raise ValueError(
            f"Unknown pool type: {tokenizer_pool_config.pool_type}")

    if RayTokenizerGroupPool is None:
        raise ImportError(
            "RayTokenizerGroupPool is not available. Please install "
            "the ray package to use the Ray tokenizer group pool.")

    return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
                                             **init_kwargs)
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment