Unverified commit d6fa1be3 authored by Zhuohan Li, committed by GitHub

[Quality] Add code formatter and linter (#326)

parent 0ffded81
@@ -55,6 +55,7 @@ class RequestOutput:
         outputs: The output sequences of the request.
         finished: Whether the whole request is finished.
     """
+
     def __init__(
         self,
         request_id: str,
@@ -75,8 +76,9 @@ class RequestOutput:
        n = seq_group.sampling_params.n
        seqs = seq_group.get_seqs()
        assert n <= len(seqs)
-        sorted_seqs = sorted(
-            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        sorted_seqs = sorted(seqs,
+                             key=lambda seq: seq.get_cumulative_logprob(),
+                             reverse=True)
        top_n_seqs = sorted_seqs[:n]
        # Create the outputs.
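The hunk above only re-wraps the `sorted` call, but the pattern it formats is worth seeing in isolation: pick the `n` most probable sequences by sorting on cumulative log probability. A minimal sketch, with a hypothetical stand-in class in place of vLLM's `Sequence`:

```python
from dataclasses import dataclass


@dataclass
class FakeSeq:
    """Hypothetical stand-in for vllm's Sequence."""
    name: str
    cumulative_logprob: float

    def get_cumulative_logprob(self) -> float:
        return self.cumulative_logprob


seqs = [FakeSeq("a", -1.2), FakeSeq("b", -0.3), FakeSeq("c", -2.5)]
n = 2
sorted_seqs = sorted(seqs,
                     key=lambda seq: seq.get_cumulative_logprob(),
                     reverse=True)
top_n_seqs = sorted_seqs[:n]
print([s.name for s in top_n_seqs])  # ['b', 'a'] -- the two most probable
```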
@@ -3,6 +3,7 @@ from typing import List, Optional, Union
 _SAMPLING_EPS = 1e-5
 
+
 class SamplingParams:
     """Sampling parameters for text generation.
@@ -51,7 +52,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         use_beam_search: bool = False,
-        stop: Union[str, List[str]] = [],
+        stop: Union[None, str, List[str]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -64,7 +65,12 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.use_beam_search = use_beam_search
-        self.stop = [stop] if isinstance(stop, str) else list(stop)
+        if stop is None:
+            self.stop = []
+        elif isinstance(stop, str):
+            self.stop = [stop]
+        else:
+            self.stop = list(stop)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
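The `stop` change above is more than formatting: a mutable default like `[]` is created once at function definition time and then shared across every call that relies on it. A minimal sketch of the pitfall and of the `None`-sentinel fix used in the diff (function names are hypothetical):

```python
from typing import List, Optional


def add_stop_buggy(token: str, stop: List[str] = []) -> List[str]:
    # The default list is created once and shared by every call.
    stop.append(token)
    return stop


print(add_stop_buggy("###"))  # ['###']
print(add_stop_buggy("EOF"))  # ['###', 'EOF'] -- state leaked between calls


def add_stop_fixed(token: str, stop: Optional[List[str]] = None) -> List[str]:
    # None is a safe, immutable sentinel; each call builds its own list.
    stop = [] if stop is None else list(stop)
    stop.append(token)
    return stop


print(add_stop_fixed("###"))  # ['###']
print(add_stop_fixed("EOF"))  # ['EOF']
```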
"""Sequence and its related classes."""
import copy import copy
import enum import enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -7,6 +8,7 @@ from vllm.sampling_params import SamplingParams ...@@ -7,6 +8,7 @@ from vllm.sampling_params import SamplingParams
class SequenceStatus(enum.Enum): class SequenceStatus(enum.Enum):
"""Status of a sequence."""
WAITING = enum.auto() WAITING = enum.auto()
RUNNING = enum.auto() RUNNING = enum.auto()
SWAPPED = enum.auto() SWAPPED = enum.auto()
...@@ -21,7 +23,7 @@ class SequenceStatus(enum.Enum): ...@@ -21,7 +23,7 @@ class SequenceStatus(enum.Enum):
SequenceStatus.FINISHED_STOPPED, SequenceStatus.FINISHED_STOPPED,
SequenceStatus.FINISHED_LENGTH_CAPPED, SequenceStatus.FINISHED_LENGTH_CAPPED,
SequenceStatus.FINISHED_ABORTED, SequenceStatus.FINISHED_ABORTED,
SequenceStatus.FINISHED_IGNORED SequenceStatus.FINISHED_IGNORED,
] ]
@staticmethod @staticmethod
...@@ -40,6 +42,17 @@ class SequenceStatus(enum.Enum): ...@@ -40,6 +42,17 @@ class SequenceStatus(enum.Enum):
class SequenceData: class SequenceData:
"""Data associated with a sequence.
Args:
prompt_token_ids: The token IDs of the prompt.
Attributes:
prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
def __init__( def __init__(
self, self,
...@@ -75,6 +88,15 @@ class SequenceData: ...@@ -75,6 +88,15 @@ class SequenceData:
class Sequence: class Sequence:
"""Stores the data, status, and block information of a sequence.
Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
"""
def __init__( def __init__(
self, self,
...@@ -149,19 +171,27 @@ class Sequence: ...@@ -149,19 +171,27 @@ class Sequence:
def is_finished(self) -> bool: def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status) return SequenceStatus.is_finished(self.status)
def fork(self, child_seq: 'Sequence') -> None: def fork(self, child_seq: "Sequence") -> None:
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks) child_seq.logical_token_blocks = copy.deepcopy(
self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs) child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
child_seq.data = copy.deepcopy(self.data) child_seq.data = copy.deepcopy(self.data)
return None
def __repr__(self) -> str: def __repr__(self) -> str:
return (f'Sequence(seq_id={self.seq_id}, ' return (f"Sequence(seq_id={self.seq_id}, "
f'status={self.status.name}, ' f"status={self.status.name}, "
f'num_blocks={len(self.logical_token_blocks)})') f"num_blocks={len(self.logical_token_blocks)})")
class SequenceGroup: class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
Args:
request_id: The ID of the request.
seqs: The list of sequences.
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
"""
def __init__( def __init__(
self, self,
...@@ -191,7 +221,7 @@ class SequenceGroup: ...@@ -191,7 +221,7 @@ class SequenceGroup:
for seq in self.seqs: for seq in self.seqs:
if seq.seq_id == seq_id: if seq.seq_id == seq_id:
return seq return seq
raise ValueError(f'Sequence {seq_id} not found.') raise ValueError(f"Sequence {seq_id} not found.")
def is_finished(self) -> bool: def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.seqs) return all(seq.is_finished() for seq in self.seqs)
...@@ -203,14 +233,25 @@ class SequenceGroup: ...@@ -203,14 +233,25 @@ class SequenceGroup:
class SequenceGroupMetadata: class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`.
Args:
request_id: The ID of the request.
is_prompt: Whether the request is at prompt stage.
seq_data: The sequence data. (Seq id -> sequence data)
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
"""
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
is_prompt: bool, is_prompt: bool,
seq_data: Dict[int, SequenceData], # Seq id -> sequence data. seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams, sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers. block_tables: Dict[int, List[int]],
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.is_prompt = is_prompt self.is_prompt = is_prompt
...@@ -220,13 +261,23 @@ class SequenceGroupMetadata: ...@@ -220,13 +261,23 @@ class SequenceGroupMetadata:
class SequenceOutputs: class SequenceOutputs:
"""The model output associated with a sequence.
Args:
seq_id: The ID of the sequence.
parent_seq_id: The ID of the parent sequence (for forking in beam
search).
output_token: The output token ID.
logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""
def __init__( def __init__(
self, self,
seq_id: int, seq_id: int,
parent_seq_id: int, parent_seq_id: int,
output_token: int, output_token: int,
logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i). logprobs: Dict[int, float],
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.parent_seq_id = parent_seq_id self.parent_seq_id = parent_seq_id
...@@ -234,15 +285,15 @@ class SequenceOutputs: ...@@ -234,15 +285,15 @@ class SequenceOutputs:
self.logprobs = logprobs self.logprobs = logprobs
def __repr__(self) -> str: def __repr__(self) -> str:
return (f'SequenceOutputs(seq_id={self.seq_id}, ' return (f"SequenceOutputs(seq_id={self.seq_id}, "
f'parent_seq_id={self.parent_seq_id}, ' f"parent_seq_id={self.parent_seq_id}, "
f'output_token={self.output_token}), ' f"output_token={self.output_token}), "
f'logprobs={self.logprobs}') f"logprobs={self.logprobs}")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs): if not isinstance(other, SequenceOutputs):
return NotImplemented return NotImplemented
return (self.seq_id == other.seq_id and return (self.seq_id == other.seq_id
self.parent_seq_id == other.parent_seq_id and and self.parent_seq_id == other.parent_seq_id
self.output_token == other.output_token and and self.output_token == other.output_token
self.logprobs == other.logprobs) and self.logprobs == other.logprobs)
@@ -13,8 +13,8 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 def get_tokenizer(
     tokenizer_name: str,
-    tokenizer_mode: str = "auto",
     *args,
+    tokenizer_mode: str = "auto",
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
@@ -73,7 +73,8 @@ def detokenize_incrementally(
     output_text = tokenizer.convert_tokens_to_string(output_tokens)
     return new_token, output_text
 
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
 # NOTE(woosuk): The following code is slow because it runs a for loop over
 # the output_tokens. In Python, running a for loop over a list can be slow
 # even when the loop body is very simple.
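Moving `tokenizer_mode` after `*args` makes it keyword-only: extra positional arguments now flow into `*args` instead of silently overriding the mode. A minimal sketch with a hypothetical function mirroring the reordered signature:

```python
def get_tok(name: str, *args, mode: str = "auto", **kwargs):
    """Hypothetical mirror of the reordered get_tokenizer signature."""
    return name, mode, args


print(get_tok("llama"))               # ('llama', 'auto', ())
print(get_tok("llama", "extra"))      # ('llama', 'auto', ('extra',))
print(get_tok("llama", mode="slow"))  # ('llama', 'slow', ())
# A positional "slow" can no longer accidentally set the mode; the mode
# must be passed by keyword.
```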
@@ -17,9 +17,9 @@ class Counter:
         self.counter = start
 
     def __next__(self) -> int:
-        id = self.counter
+        i = self.counter
         self.counter += 1
-        return id
+        return i
 
     def reset(self) -> None:
         self.counter = 0
@@ -38,6 +38,7 @@ def get_cpu_memory() -> int:
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
+
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
     return "microsoft" in " ".join(uname()).lower()
@@ -93,8 +93,8 @@ class CacheEngine:
         if not pin_memory:
             # Pinning memory in WSL is not supported.
             # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
-            logger.warn("Using 'pin_memory=False' as WSL is detected. "
-                        "This may slow down the performance.")
+            logger.warning("Using 'pin_memory=False' as WSL is detected. "
+                           "This may slow down the performance.")
         for _ in range(self.num_layers):
             key_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *key_block_shape),
@@ -120,11 +120,10 @@ class CacheEngine:
             src_key_cache, src_value_cache = src[i]
             dst_key_cache, dst_value_cache = dst[i]
             # Copy the key blocks.
-            cache_ops.swap_blocks(
-                src_key_cache, dst_key_cache, src_to_dst)
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
             # Copy the value blocks.
-            cache_ops.swap_blocks(
-                src_value_cache, dst_value_cache, src_to_dst)
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache,
+                                  src_to_dst)
             event = self.events[i]
             event.record(stream=self.cache_stream)
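`Logger.warn` is a long-deprecated alias of `Logger.warning` in the standard library's `logging` module, so the rename above trades a deprecation warning for the supported spelling with identical output:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("cache_engine_demo")

# Same message the diff above emits, via the supported method name.
logger.warning("Using 'pin_memory=False' as WSL is detected. "
               "This may slow down the performance.")
```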
@@ -73,8 +73,8 @@ class Worker:
         # number of tokens equal to max_num_batched_tokens.
         # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99,
-                                         top_k=self.model.config.vocab_size - 1)
+        vocab_size = self.model.config.vocab_size
+        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
         seqs = []
@@ -91,7 +91,8 @@ class Worker:
         )
         seqs.append(seq)
-        input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
+        input_tokens, input_positions, input_metadata = self._prepare_inputs(
+            seqs)
         # Execute the model.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
@@ -110,8 +111,9 @@ class Worker:
         total_gpu_memory = get_gpu_memory()
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, self.model_config, self.parallel_config)
-        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
-                              - peak_memory) // cache_block_size)
+        num_gpu_blocks = int(
+            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
+            cache_block_size)
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
@@ -125,8 +127,8 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
-        self.cache_engine = CacheEngine(
-            self.cache_config, self.model_config, self.parallel_config)
+        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
+                                        self.parallel_config)
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
@@ -202,8 +204,8 @@ class Worker:
             generation_block_tables.append(block_table)
             max_context_len = max(max_context_len, context_len)
-            max_num_blocks_per_seq = max(
-                max_num_blocks_per_seq, len(block_table))
+            max_num_blocks_per_seq = max(max_num_blocks_per_seq,
+                                         len(block_table))
             context_lens.append(context_len)
             block_number = block_table[position // self.block_size]
@@ -223,7 +225,8 @@ class Worker:
         context_lens_tensor = torch.cuda.IntTensor(context_lens)
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables]
+            for block_table in generation_block_tables
+        ]
         block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
         seq_data: Dict[int, SequenceData] = {}
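The reformatted `num_gpu_blocks` expression is easier to check with concrete numbers: take the fraction of GPU memory vLLM may claim, subtract the peak memory measured during profiling, and floor-divide by the per-block KV-cache size. A worked example with illustrative values (all byte counts here are assumptions, not measurements):

```python
total_gpu_memory = 16 * 1024**3   # e.g. a 16 GiB GPU
gpu_memory_utilization = 0.9      # fraction of the GPU vLLM may claim
peak_memory = 4 * 1024**3         # hypothetical profiled activation peak
cache_block_size = 2 * 1024**2    # hypothetical bytes per KV-cache block

num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) //
    cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)  # clamp, as the worker code does
print(num_gpu_blocks)  # 5324 KV-cache blocks fit in the remaining memory
```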