Unverified commit d6fa1be3 authored by Zhuohan Li, committed by GitHub

[Quality] Add code formatter and linter (#326)

parent 0ffded81
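The commit title does not name the tools being added; judging from the argument-aligned continuation style in the hunks below, a yapf-style formatter (plus a linter such as pylint) is a reasonable guess. As a hedged illustration only, the sketch below shows how a formatter of that kind can be driven from Python; the style setting and input string are assumptions, not the configuration introduced by this commit.

```python
# Illustrative only: drive a yapf-style formatter over a source snippet.
# yapf's FormatCode API is real, but the style choice here is an assumption,
# not the configuration added by this commit.
from yapf.yapflib.yapf_api import FormatCode

source = (
    "sorted_seqs = sorted(\n"
    "    seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)\n"
)

# FormatCode returns a (formatted_source, changed) tuple.
formatted, changed = FormatCode(source, style_config="pep8")
print(formatted)
```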
......@@ -55,6 +55,7 @@ class RequestOutput:
outputs: The output sequences of the request.
finished: Whether the whole request is finished.
"""
def __init__(
self,
request_id: str,
......@@ -75,8 +76,9 @@ class RequestOutput:
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
assert n <= len(seqs)
sorted_seqs = sorted(
seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
sorted_seqs = sorted(seqs,
key=lambda seq: seq.get_cumulative_logprob(),
reverse=True)
top_n_seqs = sorted_seqs[:n]
# Create the outputs.
......
......@@ -3,6 +3,7 @@ from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5
class SamplingParams:
"""Sampling parameters for text generation.
......@@ -51,7 +52,7 @@ class SamplingParams:
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
stop: Union[str, List[str]] = [],
stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
......@@ -64,7 +65,12 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.use_beam_search = use_beam_search
self.stop = [stop] if isinstance(stop, str) else list(stop)
if stop is None:
self.stop = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = list(stop)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs
......
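The `stop` change in this hunk swaps a mutable default argument (`[]`) for a `None` sentinel, the idiom linters such as pylint flag as `dangerous-default-value`, and normalizes it in `__init__`. A minimal, self-contained sketch of the same pattern (the function name is illustrative, not vLLM's):

```python
from typing import List, Union

def normalize_stop(stop: Union[None, str, List[str]] = None) -> List[str]:
    # Accept None, a single string, or a list of strings without ever
    # sharing a mutable default list between calls.
    if stop is None:
        return []
    if isinstance(stop, str):
        return [stop]
    return list(stop)

assert normalize_stop() == []
assert normalize_stop("###") == ["###"]
assert normalize_stop(["###", "\n\n"]) == ["###", "\n\n"]
```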
"""Sequence and its related classes."""
import copy
import enum
from typing import Dict, List, Optional, Union
......@@ -7,6 +8,7 @@ from vllm.sampling_params import SamplingParams
class SequenceStatus(enum.Enum):
"""Status of a sequence."""
WAITING = enum.auto()
RUNNING = enum.auto()
SWAPPED = enum.auto()
......@@ -21,7 +23,7 @@ class SequenceStatus(enum.Enum):
SequenceStatus.FINISHED_STOPPED,
SequenceStatus.FINISHED_LENGTH_CAPPED,
SequenceStatus.FINISHED_ABORTED,
SequenceStatus.FINISHED_IGNORED
SequenceStatus.FINISHED_IGNORED,
]
@staticmethod
......@@ -40,6 +42,17 @@ class SequenceStatus(enum.Enum):
class SequenceData:
"""Data associated with a sequence.
Args:
prompt_token_ids: The token IDs of the prompt.
Attributes:
prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
def __init__(
self,
......@@ -75,6 +88,15 @@ class SequenceData:
class Sequence:
"""Stores the data, status, and block information of a sequence.
Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
"""
def __init__(
self,
......@@ -149,19 +171,27 @@ class Sequence:
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
def fork(self, child_seq: 'Sequence') -> None:
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
def fork(self, child_seq: "Sequence") -> None:
child_seq.logical_token_blocks = copy.deepcopy(
self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
child_seq.data = copy.deepcopy(self.data)
return None
def __repr__(self) -> str:
return (f'Sequence(seq_id={self.seq_id}, '
f'status={self.status.name}, '
f'num_blocks={len(self.logical_token_blocks)})')
return (f"Sequence(seq_id={self.seq_id}, "
f"status={self.status.name}, "
f"num_blocks={len(self.logical_token_blocks)})")
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
Args:
request_id: The ID of the request.
seqs: The list of sequences.
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
"""
def __init__(
self,
......@@ -191,7 +221,7 @@ class SequenceGroup:
for seq in self.seqs:
if seq.seq_id == seq_id:
return seq
raise ValueError(f'Sequence {seq_id} not found.')
raise ValueError(f"Sequence {seq_id} not found.")
def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.seqs)
......@@ -203,14 +233,25 @@ class SequenceGroup:
class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`.
Args:
request_id: The ID of the request.
is_prompt: Whether the request is at prompt stage.
seq_data: The sequence data. (Seq id -> sequence data)
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
"""
def __init__(
self,
request_id: str,
is_prompt: bool,
seq_data: Dict[int, SequenceData], # Seq id -> sequence data.
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers.
block_tables: Dict[int, List[int]],
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
......@@ -220,13 +261,23 @@ class SequenceGroupMetadata:
class SequenceOutputs:
"""The model output associated with a sequence.
Args:
seq_id: The ID of the sequence.
parent_seq_id: The ID of the parent sequence (for forking in beam
search).
output_token: The output token ID.
logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""
def __init__(
self,
seq_id: int,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i).
logprobs: Dict[int, float],
) -> None:
self.seq_id = seq_id
self.parent_seq_id = parent_seq_id
......@@ -234,15 +285,15 @@ class SequenceOutputs:
self.logprobs = logprobs
def __repr__(self) -> str:
return (f'SequenceOutputs(seq_id={self.seq_id}, '
f'parent_seq_id={self.parent_seq_id}, '
f'output_token={self.output_token}), '
f'logprobs={self.logprobs}')
return (f"SequenceOutputs(seq_id={self.seq_id}, "
f"parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}), "
f"logprobs={self.logprobs}")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
return NotImplemented
return (self.seq_id == other.seq_id and
self.parent_seq_id == other.parent_seq_id and
self.output_token == other.output_token and
self.logprobs == other.logprobs)
return (self.seq_id == other.seq_id
and self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)
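The reflowed `__eq__` keeps the `NotImplemented` convention: returning `NotImplemented` (rather than `False`) for a foreign type lets Python try the other operand's comparison before concluding the objects differ. A standalone sketch of that protocol, independent of the vLLM classes:

```python
class Point:
    """Toy value type used only to demonstrate the __eq__ protocol."""

    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Point):
            # Defer to the other operand instead of deciding inequality here.
            return NotImplemented
        return self.x == other.x and self.y == other.y

assert Point(1, 2) == Point(1, 2)
assert Point(1, 2) != "not a point"  # falls back cleanly instead of raising
```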
......@@ -13,8 +13,8 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
def get_tokenizer(
tokenizer_name: str,
tokenizer_mode: str = "auto",
*args,
tokenizer_mode: str = "auto",
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
......@@ -73,7 +73,8 @@ def detokenize_incrementally(
output_text = tokenizer.convert_tokens_to_string(output_tokens)
return new_token, output_text
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
......
......@@ -17,9 +17,9 @@ class Counter:
self.counter = start
def __next__(self) -> int:
id = self.counter
i = self.counter
self.counter += 1
return id
return i
def reset(self) -> None:
self.counter = 0
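Renaming the local `id` to `i` in `Counter.__next__` stops shadowing the built-in `id()`, which linters report as `redefined-builtin`; the behavior is unchanged. A self-contained version of the same counter pattern (the `start` parameter is assumed from the `self.counter = start` line above):

```python
class Counter:
    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter  # local name no longer shadows the built-in id()
        self.counter += 1
        return i

    def reset(self) -> None:
        self.counter = 0

request_counter = Counter()
assert next(request_counter) == 0
assert next(request_counter) == 1
```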
......@@ -38,6 +38,7 @@ def get_cpu_memory() -> int:
def random_uuid() -> str:
return str(uuid.uuid4().hex)
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower()
......@@ -93,8 +93,8 @@ class CacheEngine:
if not pin_memory:
# Pinning memory in WSL is not supported.
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
logger.warn("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
logger.warning("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_cpu_blocks, *key_block_shape),
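The `logger.warn(...)` to `logger.warning(...)` change above replaces the long-deprecated alias in the standard `logging` module with the supported method; the message is unchanged. A tiny standalone equivalent (the logger name is illustrative):

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("cache_engine_demo")  # illustrative name

# Logger.warning is the supported spelling; Logger.warn is a deprecated alias.
logger.warning("Using 'pin_memory=False' as WSL is detected. "
               "This may slow down the performance.")
```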
......@@ -120,11 +120,10 @@ class CacheEngine:
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
cache_ops.swap_blocks(
src_key_cache, dst_key_cache, src_to_dst)
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
cache_ops.swap_blocks(
src_value_cache, dst_value_cache, src_to_dst)
cache_ops.swap_blocks(src_value_cache, dst_value_cache,
src_to_dst)
event = self.events[i]
event.record(stream=self.cache_stream)
......
......@@ -73,8 +73,8 @@ class Worker:
# number of tokens equal to max_num_batched_tokens.
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99,
top_k=self.model.config.vocab_size - 1)
vocab_size = self.model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
seqs = []
......@@ -91,7 +91,8 @@ class Worker:
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seqs)
# Execute the model.
num_layers = self.model_config.get_num_layers(self.parallel_config)
......@@ -110,8 +111,9 @@ class Worker:
total_gpu_memory = get_gpu_memory()
cache_block_size = CacheEngine.get_cache_block_size(
block_size, self.model_config, self.parallel_config)
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
- peak_memory) // cache_block_size)
num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) //
cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
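The reflowed expression above derives the number of GPU KV-cache blocks: total GPU memory scaled by the utilization target, minus the peak memory observed while profiling, divided by the per-block size, clamped at zero. A worked sketch with made-up numbers (every value below is hypothetical, purely to show the arithmetic):

```python
# Hypothetical values, only to illustrate the block-count arithmetic.
total_gpu_memory = 80 * 1024**3       # 80 GiB device
gpu_memory_utilization = 0.90         # fraction of the device vLLM may use
peak_memory = 30 * 1024**3            # peak usage measured during profiling
cache_block_size = 2 * 1024**2        # bytes per KV-cache block (assumed)

num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) //
    cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)  # never allow a negative block count

print(num_gpu_blocks)  # 21504 with the numbers above
```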
......@@ -125,8 +127,8 @@ class Worker:
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.block_size = cache_config.block_size
self.cache_engine = CacheEngine(
self.cache_config, self.model_config, self.parallel_config)
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.cache_events = self.cache_engine.events
self.gpu_cache = self.cache_engine.gpu_cache
......@@ -202,8 +204,8 @@ class Worker:
generation_block_tables.append(block_table)
max_context_len = max(max_context_len, context_len)
max_num_blocks_per_seq = max(
max_num_blocks_per_seq, len(block_table))
max_num_blocks_per_seq = max(max_num_blocks_per_seq,
len(block_table))
context_lens.append(context_len)
block_number = block_table[position // self.block_size]
......@@ -223,7 +225,8 @@ class Worker:
context_lens_tensor = torch.cuda.IntTensor(context_lens)
padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables]
for block_table in generation_block_tables
]
block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
seq_data: Dict[int, SequenceData] = {}
......
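`_pad_to_max` is used in the hunk above but its body is not part of this diff; a plausible sketch of such a padding helper (an assumption, not necessarily the real vLLM implementation) makes the list comprehension concrete:

```python
from typing import List

def _pad_to_max(x: List[int], max_len: int, pad: int = 0) -> List[int]:
    # Assumed behavior: right-pad a block table so every row has equal length.
    return x + [pad] * (max_len - len(x))

generation_block_tables = [[3, 7], [5], [2, 9, 11]]
max_num_blocks_per_seq = max(len(t) for t in generation_block_tables)
padded_block_tables = [
    _pad_to_max(block_table, max_num_blocks_per_seq)
    for block_table in generation_block_tables
]
assert padded_block_tables == [[3, 7, 0], [5, 0, 0], [2, 9, 11]]
```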