Unverified Commit a0086298 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Misc] Various simplifications and typing fixes (#5368)

parent 76477a93
...@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group, we can take the # Since there's only one sequence per sequence group, we can take the
# first sample. # first sample.
samples = [outputs[step].samples[0] for step in range(len(outputs))] samples = [output.samples[0] for output in outputs]
# -1 means the output token is not valid (eg. due to spec decode # -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens). # rejecting tokens).
......
...@@ -306,8 +306,10 @@ class RejectionSampler(nn.Module): ...@@ -306,8 +306,10 @@ class RejectionSampler(nn.Module):
# Fill in the first k columns of the output tensor using masks and data # Fill in the first k columns of the output tensor using masks and data
# tensors. # tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids, torch.where(accepted_mask,
-torch.ones_like(draft_token_ids)) draft_token_ids,
-torch.ones_like(draft_token_ids),
out=output)
# Fill the last column. # Fill the last column.
# We check output directly as accepted may have True values inconsistent # We check output directly as accepted may have True values inconsistent
......
...@@ -80,7 +80,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -80,7 +80,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
target_sampler_output = self._scorer_worker.execute_model( target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone( execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list, )) seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output" assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0] target_sampler_output = target_sampler_output[0]
...@@ -140,8 +140,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -140,8 +140,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
num_scoring_tokens) num_scoring_tokens)
def _contract_batch( def _contract_batch(
self, contracted_bs: int, self, contracted_bs: int, target_sampler_output: SamplerOutput,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals, num_scoring_tokens: int, proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int], non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
...@@ -167,30 +166,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -167,30 +166,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.squeeze().reshape( target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
spec_expanded_bs, k + 1) target_probs = target_probs.reshape(*target_token_ids.shape,
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, self._vocab_size)
self._vocab_size) target_logprobs = target_logprobs.reshape(target_probs.shape)
target_logprobs = target_logprobs.squeeze().reshape(
spec_expanded_bs, k + 1, self._vocab_size) all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_tokens = torch.full(size=(contracted_bs, k + 1), all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
fill_value=-1, all_logprobs = target_logprobs.new_full(size=all_probs.shape,
device=self._device, fill_value=-float("inf"))
dtype=torch.long)
all_probs = torch.zeros(contracted_bs,
k + 1,
self._vocab_size,
device=self._device,
dtype=torch.float32)
all_logprobs = torch.full(size=(
contracted_bs,
k + 1,
self._vocab_size,
),
fill_value=-float("inf"),
device=self._device,
dtype=torch.float32)
if non_spec_indices: if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
......
...@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple ...@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from vllm.config import SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
...@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
""" """
assert "speculative_config" in kwargs assert "speculative_config" in kwargs
speculative_config = kwargs.get("speculative_config") speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
assert speculative_config is not None assert speculative_config is not None
target_worker = Worker(*args, **kwargs) target_worker = Worker(*args, **kwargs)
...@@ -109,12 +110,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -109,12 +110,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("Configuring SpecDecodeWorker with proposer=%s", logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker)) type(proposer_worker))
return SpecDecodeWorker( return SpecDecodeWorker(proposer_worker,
proposer_worker, scorer_worker,
scorer_worker, disable_by_batch_size=disable_by_batch_size,
disable_by_batch_size=disable_by_batch_size, rejection_sampler=RejectionSampler(
rejection_sampler=RejectionSampler( disable_bonus_tokens=disable_bonus_tokens))
disable_bonus_tokens=disable_bonus_tokens, ))
def __init__( def __init__(
self, self,
......
...@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer): ...@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_indices, nonzero_proposal_len_indices,
) )
def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output, @staticmethod
def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices, transposed): nonzero_proposal_len_indices, transposed):
"""Remove sequences from nonzero_proposal_len_indices and reset """Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal their proposal_len to 0 the draft worker does not provide a proposal
...@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer): ...@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
self, self,
batch_size: int, batch_size: int,
proposal_len: int, proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput], maybe_sampler_output: Optional[List[SamplerOutput]],
proposal_lens: List[int], proposal_lens: List[int],
nonzero_proposal_len_indices: List[int], nonzero_proposal_len_indices: List[int],
sampler_transposed: bool, sampler_transposed: bool,
...@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer): ...@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
if maybe_sampler_output is None: if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None. # If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals. # In this case we return empty proposals.
proposal_tokens = torch.full( proposal_tokens = torch.tensor(-1,
size=( dtype=torch.long,
batch_size, device=self._device).expand(
proposal_len, batch_size, proposal_len)
), proposal_probs = torch.tensor(0,
fill_value=-1, dtype=torch.float32,
dtype=torch.long, device=self._device).expand(
device=self._device, batch_size, proposal_len,
) self._vocab_size)
proposal_probs = torch.zeros( proposal_lens_tensor = torch.tensor(0,
batch_size, dtype=torch.long,
proposal_len, device=self._device).expand(
self._vocab_size, len(proposal_lens))
dtype=torch.float32,
device=self._device,
)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output sampler_output = maybe_sampler_output
...@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer): ...@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
# Now, reformat the output GPU tensors such that each sequence has # Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1] # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = torch.full( entire_proposal_tokens = proposal_tokens.new_full(
size=(batch_size, *proposal_tokens.shape[1:]), size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1, fill_value=-1,
dtype=torch.long,
device=self._device,
) )
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros( entire_proposal_probs = proposal_probs.new_zeros(
batch_size, batch_size,
*proposal_probs.shape[1:], *proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device,
) )
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
......
from contextlib import contextmanager from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput) SequenceOutput)
SeqId = int SeqId = int
...@@ -16,11 +15,7 @@ def get_all_seq_ids( ...@@ -16,11 +15,7 @@ def get_all_seq_ids(
"""Given a list of SequenceGroupMetadata, create a list of all """Given a list of SequenceGroupMetadata, create a list of all
sequence ids. sequence ids.
""" """
return list( return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
chain.from_iterable([
seq_group_metadata.seq_data.keys()
for seq_group_metadata in seq_group_metadata_list
]))
def get_all_num_logprobs( def get_all_num_logprobs(
...@@ -68,7 +63,7 @@ def create_sequence_group_output( ...@@ -68,7 +63,7 @@ def create_sequence_group_output(
seq_id: SeqId, seq_id: SeqId,
topk_token_ids: List[int], topk_token_ids: List[int],
topk_logprobs: List[float], topk_logprobs: List[float],
) -> SequenceGroupOutput: ) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results. """Create a SequenceGroupOutput given the sampling results.
Args: Args:
......
from typing import Dict, Optional from typing import Dict, Optional, Type
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -9,7 +9,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, ...@@ -9,7 +9,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
logger = init_logger(__name__) logger = init_logger(__name__)
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
"mpt": MPTConfig, "mpt": MPTConfig,
...@@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig): ...@@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig):
assert hasattr(config.text_config, "num_attention_heads") assert hasattr(config.text_config, "num_attention_heads")
return config.text_config return config.text_config
else: else:
return config return config
\ No newline at end of file
...@@ -527,16 +527,6 @@ class ModelRunner: ...@@ -527,16 +527,6 @@ class ModelRunner:
) )
assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens, seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
...@@ -544,11 +534,6 @@ class ModelRunner: ...@@ -544,11 +534,6 @@ class ModelRunner:
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
torch.cumsum(seq_lens_tensor, torch.cumsum(seq_lens_tensor,
dim=0, dim=0,
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
...@@ -601,6 +586,21 @@ class ModelRunner: ...@@ -601,6 +586,21 @@ class ModelRunner:
seq_start_loc=seq_start_loc, seq_start_loc=seq_start_loc,
data_type=kv_cache_dtype) data_type=kv_cache_dtype)
else: else:
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment