Unverified commit d6fa1be3, authored by Zhuohan Li, committed by GitHub

[Quality] Add code formatter and linter (#326)

parent 0ffded81
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
logging-fstring-interpolation, # added by vLLM
logging-not-lazy, # added by vLLM
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-class-docstring, # TODO (vLLM): enable
missing-function-docstring,
missing-module-docstring, # TODO (vLLM): enable
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
unspecified-encoding,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables 'error', 'warning', 'refactor',
# 'convention', and 'statement', which respectively contain the number of
# messages of each category and the total number of statements analyzed. This
# is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
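# For example, a run over 100 statements that produces 1 error and 4 warnings
# scores 10.0 - ((5 * 1 + 4) / 100) * 10 = 9.1 out of 10.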
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
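# For example, a line consisting only of a long URL (optionally behind "# ")
# or a long "import x" / "from x import y" statement is exempt from the limit.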
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externally-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
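# For example, "_", "unused_result", "dummy_arg", and "*_" all match this
# pattern and are treated as intentionally unused.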
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum number of lines of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException
@@ -49,12 +49,15 @@
 In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
+We include a formatting script [`format.sh`](./format.sh) to format the code.
 
 ### Pull Requests
 
 When submitting a pull request:
 
 1. Make sure your code has been rebased on top of the latest commit on the main branch.
-2. Include a detailed description of the changes in the pull request.
+2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
+3. Include a detailed description of the changes in the pull request.
    Explain why you made the changes you did.
    If your pull request fixes an open issue, please include a reference to it in the description.
...
@@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
         print(LINE_UP, end=LINE_CLEAR, flush=True)
 
 
-def post_http_request(prompt: str, api_url: str, n: int = 1,
+def post_http_request(prompt: str,
+                      api_url: str,
+                      n: int = 1,
                       stream: bool = False) -> requests.Response:
     headers = {"User-Agent": "Test Client"}
     pload = {
@@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,
 
 
 def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
                                      delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
...
@@ -12,9 +12,14 @@ def http_bot(prompt):
         "stream": True,
         "max_tokens": 128,
     }
-    response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
+    response = requests.post(args.model_url,
+                             headers=headers,
+                             json=pload,
+                             stream=True)
 
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
+                                     delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
             output = data["text"][0]
@@ -23,11 +28,11 @@ def http_bot(prompt):
 
 def build_demo():
     with gr.Blocks() as demo:
-        gr.Markdown(
-            "# vLLM text completion demo\n"
-        )
-        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
-        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
+        gr.Markdown("# vLLM text completion demo\n")
+        inputbox = gr.Textbox(label="Input",
+                              placeholder="Enter text and press ENTER")
+        outputbox = gr.Textbox(label="Output",
+                               placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
     return demo
 
@@ -36,7 +41,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
+    parser.add_argument("--model-url",
+                        type=str,
+                        default="http://localhost:8000/generate")
     args = parser.parse_args()
 
     demo = build_demo()
...
@@ -14,9 +14,14 @@ def main(args: argparse.Namespace):
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
-         SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
+         SamplingParams(n=2,
+                        best_of=5,
+                        temperature=0.8,
+                        top_p=0.95,
+                        frequency_penalty=0.1)),
         ("It is only with the heart that one can see rightly",
-         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
+         SamplingParams(n=3, best_of=3, use_beam_search=True,
+                        temperature=0.0)),
     ]
 
     # Run the engine by calling `engine.step()` manually.
...
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
...
@@ -12,8 +12,13 @@ print("Models:", models)
 # Test completion API
 stream = True
 completion = openai.Completion.create(
-    model=model, prompt="A robot may not injure a human being", echo=False, n=2,
-    best_of=3, stream=stream, logprobs=3)
+    model=model,
+    prompt="A robot may not injure a human being",
+    echo=False,
+    n=2,
+    best_of=3,
+    stream=stream,
+    logprobs=3)
 
 # print the completion
 if stream:
...
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
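#
# The script supports three modes (see the option handling at the bottom of
# this file):
#    bash format.sh                     # format files changed relative to origin/main
#    bash format.sh --files FILE [...]  # format only the given files
#    bash format.sh --all               # format the entire vllm package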
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
# # params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
'--exclude' 'vllm/model_executor/parallel_utils/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
# Run Pylint
echo 'vLLM Pylint:'
pylint vllm
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi
-mypy
+# formatting
+yapf==0.32.0
+pylint==2.8.2
+
+# type checking
+mypy==0.991
+types-PyYAML
+types-requests
+types-setuptools
+
+# testing
 pytest
@@ -60,7 +60,7 @@ def ref_single_query_cached_kv_attention(
         keys = torch.stack(keys, dim=0)
         values = torch.stack(values, dim=0)
 
-        scale = 1.0 / (head_size ** 0.5)
+        scale = 1.0 / (head_size**0.5)
         out = ref_masked_attention(q, keys, values, scale)
         out = out.view(num_heads, head_size)
         output[i].copy_(out, non_blocking=True)
@@ -74,7 +74,7 @@ def ref_multi_query_kv_attention(
     dtype: torch.dtype,
 ) -> torch.Tensor:
     head_size = query.shape[-1]
-    scale = 1.0 / (head_size ** 0.5)
+    scale = 1.0 / (head_size**0.5)
 
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs = []
@@ -84,8 +84,8 @@ def ref_multi_query_kv_attention(
         seq_len = end_idx - start_idx
 
         # Create attention mask.
-        attn_mask = torch.triu(
-            torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                               diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')
 
@@ -113,7 +113,7 @@ def ref_multi_query_cached_kv_attention(
     num_heads = value_cache.shape[1]
     head_size = value_cache.shape[2]
     block_size = value_cache.shape[3]
-    scale = 1.0 / (head_size ** 0.5)
+    scale = 1.0 / (head_size**0.5)
 
     num_queries = len(cu_query_lens) - 1
     ref_outputs = []
@@ -125,8 +125,8 @@ def ref_multi_query_cached_kv_attention(
         block_table = block_tables[i]
 
         # Create attention mask
-        attn_mask = torch.triu(
-            torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
+        attn_mask = torch.triu(torch.ones(query_len, context_len),
+                               diagonal=context_len - query_len + 1) * -1e5
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')
 
         keys = []
@@ -165,22 +165,28 @@ def run_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    qkv = torch.empty(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.empty(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     qkv.uniform_(-1e-3, 1e-3)
     query, _, _ = qkv.unbind(dim=1)
 
     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(
-        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    key_cache = torch.empty(size=(num_blocks, *key_block_shape),
+                            dtype=dtype,
+                            device='cuda')
     key_cache.uniform_(-1e-3, 1e-3)
     value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.empty(
-        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
+    value_cache = torch.empty(size=(num_blocks, *value_block_shape),
+                              dtype=dtype,
+                              device='cuda')
     value_cache.uniform_(-1e-3, 1e-3)
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
@@ -194,9 +200,12 @@ def run_single_query_cached_kv_attention(
         block_tables.append(block_table)
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
 
-    scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    scale = float(1.0 / (head_size**0.5))
+    output = torch.empty(num_tokens,
+                         num_heads,
+                         head_size,
+                         dtype=dtype,
+                         device='cuda')
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
@@ -235,9 +244,13 @@ def run_multi_query_kv_attention(
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
     num_tokens = sum(seq_lens)
 
-    scale = float(1.0 / (head_size ** 0.5))
-    qkv = torch.empty(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    scale = float(1.0 / (head_size**0.5))
+    qkv = torch.empty(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)
...
@@ -26,8 +26,9 @@ def run_copy_blocks(
     key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
     key_caches = []
     for _ in range(num_layers):
-        key_cache = torch.randn(
-            size=key_cache_shape, dtype=dtype, device='cuda')
+        key_cache = torch.randn(size=key_cache_shape,
+                                dtype=dtype,
+                                device='cuda')
         key_caches.append(key_cache)
     cloned_key_caches = []
     for key_cache in key_caches:
@@ -36,8 +37,9 @@ def run_copy_blocks(
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
     value_caches = []
     for _ in range(num_layers):
-        value_cache = torch.randn(
-            size=value_cache_shape, dtype=dtype, device='cuda')
+        value_cache = torch.randn(size=value_cache_shape,
+                                  dtype=dtype,
+                                  device='cuda')
         value_caches.append(value_cache)
     cloned_value_caches = []
     for value_cache in value_caches:
@@ -49,15 +51,18 @@ def run_copy_blocks(
     # Reference implementation.
     for src, dsts in block_mapping.items():
         for dst in dsts:
-            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+            for key_cache, cloned_key_cache in zip(key_caches,
+                                                   cloned_key_caches):
                 cloned_key_cache[dst] = cloned_key_cache[src]
-            for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+            for value_cache, cloned_value_cache in zip(value_caches,
+                                                       cloned_value_caches):
                 cloned_value_cache[dst] = cloned_value_cache[src]
 
     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
         assert torch.allclose(key_cache, cloned_key_cache)
-    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+    for value_cache, cloned_value_cache in zip(value_caches,
+                                               cloned_value_caches):
         assert torch.allclose(value_cache, cloned_value_cache)
@@ -74,8 +79,12 @@ def run_reshape_and_cache(
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
 
-    qkv = torch.randn(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     _, key, value = qkv.unbind(dim=1)
 
     x = 16 // torch.tensor([], dtype=dtype).element_size()
@@ -84,15 +93,19 @@ def run_reshape_and_cache(
     cloned_key_cache = key_cache.clone()
 
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(
-        size=value_cache_shape, dtype=dtype, device='cuda')
+    value_cache = torch.randn(size=value_cache_shape,
+                              dtype=dtype,
+                              device='cuda')
     cloned_value_cache = value_cache.clone()
 
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
+                                slot_mapping)
 
     for i in range(num_tokens):
         reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
+        block_idx = torch.div(slot_mapping[i],
+                              block_size,
+                              rounding_mode='floor')
         block_offset = slot_mapping[i] % block_size
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]
@@ -114,8 +127,12 @@ def run_gather_cached_kv(
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
 
-    qkv = torch.randn(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     _, key, value = qkv.unbind(dim=1)
     qkv_clone = qkv.clone()
@@ -126,15 +143,20 @@ def run_gather_cached_kv(
     key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
 
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(
-        size=value_cache_shape, dtype=dtype, device='cuda')
+    value_cache = torch.randn(size=value_cache_shape,
+                              dtype=dtype,
+                              device='cuda')
 
-    cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
+    cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
+                               slot_mapping)
 
     # Reference implementation.
     for i in range(num_tokens):
-        reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
+        reshaped_key = cloned_key.reshape(num_tokens, num_heads,
+                                          head_size // x, x)
+        block_idx = torch.div(slot_mapping[i],
+                              block_size,
+                              rounding_mode='floor')
         block_offset = slot_mapping[i] % block_size
         reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
         cloned_value[i] = value_cache[block_idx, :, :, block_offset]
@@ -145,20 +167,30 @@ def run_gather_cached_kv(
 
 def test_copy_blocks() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_copy_blocks(
-            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
-            block_size=8, num_blocks=1024, dtype=dtype)
+        run_copy_blocks(num_mappings=23,
+                        num_layers=7,
+                        num_heads=17,
+                        head_size=16,
+                        block_size=8,
+                        num_blocks=1024,
+                        dtype=dtype)
 
 
 def test_reshape_and_cache() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_reshape_and_cache(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
+        run_reshape_and_cache(num_tokens=3,
+                              num_heads=2,
+                              head_size=16,
+                              block_size=8,
+                              num_blocks=2,
+                              dtype=dtype)
 
 
 def test_gather_cached_kv() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_gather_cached_kv(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
+        run_gather_cached_kv(num_tokens=3,
+                             num_heads=2,
+                             head_size=16,
+                             block_size=8,
+                             num_blocks=2,
+                             dtype=dtype)
@@ -14,8 +14,10 @@ class RefRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
+                                                               keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance +
+                                                    self.variance_epsilon)
         if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
         return self.weight * hidden_states
...
@@ -8,8 +8,8 @@ from vllm import pos_encoding_ops
 
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
 
 
@@ -38,7 +38,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
         self.max_position_embeddings = max_position_embeddings
 
         # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
+        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
         t = torch.arange(max_position_embeddings).float()
         freqs = torch.einsum("i,j->ij", t, inv_freq.float())
         emb = torch.cat((freqs, freqs), dim=-1)
@@ -49,16 +49,15 @@ class RefRotaryEmbeddingNeox(nn.Module):
     def forward(
         self,
         positions: torch.Tensor,  # [num_tokens]
         query: torch.Tensor,  # [num_tokens, num_heads, head_size]
         key: torch.Tensor,  # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
 
         query_rot = query_rot.transpose(0, 1)
         key_rot = key_rot.transpose(0, 1)
@@ -85,12 +84,18 @@ def run_rotary_embedding_neox(
     dtype: torch.dtype,
     base: int = 10000,
 ) -> None:
-    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
-    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
-    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
+    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+    query = torch.randn(num_tokens,
+                        num_heads * head_size,
+                        dtype=dtype,
+                        device='cuda')
+    key = torch.randn(num_tokens,
+                      num_heads * head_size,
+                      dtype=dtype,
+                      device='cuda')
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
     t = torch.arange(max_position).float()
     freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
     cos = freqs.cos()
...
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
......
@@ -35,7 +35,8 @@ class LogicalTokenBlock:
     def append_tokens(self, token_ids: List[int]) -> None:
         assert len(token_ids) <= self.get_num_empty_slots()
-        self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
+        curr_idx = self.num_tokens
+        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
         self.num_tokens += len(token_ids)
 
     def get_token_ids(self) -> List[int]:
...
@@ -8,7 +8,7 @@ from vllm.utils import get_cpu_memory
 
 logger = init_logger(__name__)
 
-_GiB = 1 << 30
+_GB = 1 << 30
 
 
 class ModelConfig:
@@ -106,6 +106,7 @@ class CacheConfig:
         vLLM execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
     """
+
     def __init__(
         self,
         block_size: int,
@@ -114,7 +115,7 @@ class CacheConfig:
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space_bytes = swap_space * _GiB
+        self.swap_space_bytes = swap_space * _GB
         self._verify_args()
 
         # Will be set after profiling.
@@ -137,14 +138,13 @@ class CacheConfig:
         num_gpus_per_node = parallel_config.tensor_parallel_size
         cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
 
-        msg = (
-            f"{cpu_memory_usage / _GiB:.2f} GiB out of "
-            f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
-            "allocated for the swap space.")
+        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
+               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
+               "allocated for the swap space.")
         if cpu_memory_usage > 0.7 * total_cpu_memory:
             raise ValueError("Too large swap space. " + msg)
         elif cpu_memory_usage > 0.4 * total_cpu_memory:
-            logger.warn("Possibly too large swap space. " + msg)
+            logger.warning("Possibly too large swap space. " + msg)
 
 
 class ParallelConfig:
@@ -157,6 +157,7 @@ class ParallelConfig:
         True if either pipeline_parallel_size or tensor_parallel_size is
         greater than 1.
     """
+
     def __init__(
         self,
         pipeline_parallel_size: int,
@@ -189,12 +190,9 @@ class SchedulerConfig:
         max_seq_len: Maximum length of a sequence (including prompt
             and generated text).
     """
-    def __init__(
-        self,
-        max_num_batched_tokens: int,
-        max_num_seqs: int,
-        max_seq_len: int
-    ) -> None:
+
+    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
+                 max_seq_len: int) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
         self.max_seq_len = max_seq_len
@@ -241,7 +239,7 @@ def _get_and_verify_dtype(
         pass
     else:
         # Casting between float16 and bfloat16 is allowed with a warning.
-        logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
+        logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
 
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
...
@@ -27,8 +27,9 @@ class BlockAllocator:
         # Initialize the free blocks.
         self.free_blocks: List[PhysicalTokenBlock] = []
         for i in range(num_blocks):
-            block = PhysicalTokenBlock(
-                device=device, block_number=i, block_size=block_size)
+            block = PhysicalTokenBlock(device=device,
+                                       block_number=i,
+                                       block_size=block_size)
             self.free_blocks.append(block)
 
     def allocate(self) -> PhysicalTokenBlock:
@@ -84,10 +85,12 @@ class BlockSpaceManager:
         num_required_blocks = len(seq.logical_token_blocks)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         # Use watermark to avoid frequent cache eviction.
-        return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks
+        return (num_free_gpu_blocks - num_required_blocks >=
+                self.watermark_blocks)
 
     def allocate(self, seq_group: SequenceGroup) -> None:
-        # NOTE: Here we assume that all sequences in the group have the same prompt.
+        # NOTE: Here we assume that all sequences in the group have the same
+        # prompt.
         seq = seq_group.get_seqs()[0]
 
         # Allocate new physical token blocks that will store the prompt tokens.
@@ -143,7 +146,8 @@ class BlockSpaceManager:
         for block in src_block_table:
             block.ref_count += 1
 
-    def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
+    def _get_physical_blocks(
+            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
         # NOTE: Here, we assume that the physical blocks are only shared by
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
...
@@ -43,8 +43,7 @@ class SchedulerOutputs:
         assert not (blocks_to_swap_in and blocks_to_swap_out)
 
     def is_empty(self) -> bool:
-        return (not self.blocks_to_swap_in
-                and not self.blocks_to_swap_out
+        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
                 and not self.blocks_to_copy)
 
@@ -61,7 +60,7 @@ class Scheduler:
         self.log_stats = log_stats
 
         # Instantiate the scheduling policy.
-        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
+        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
         # Create the block space manager.
         self.block_manager = BlockSpaceManager(
             block_size=self.cache_config.block_size,
@@ -102,7 +101,8 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
-    def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(
+            self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
@@ -160,7 +160,8 @@ class Scheduler:
                 num_curr_seqs = sum(
                     seq_group.num_seqs(status=SequenceStatus.RUNNING)
                     for seq_group in self.running)
-                if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
                     break
 
                 seq_group = self.swapped.pop(0)
@@ -170,8 +171,7 @@ class Scheduler:
         num_batched_tokens = sum(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
-            for seq_group in self.running
-        )
+            for seq_group in self.running)
 
         # Join waiting sequences if possible.
         prompt_group_ids: List[str] = []
@@ -191,7 +191,7 @@ class Scheduler:
             num_prompt_tokens = seq_group.get_seqs()[0].get_len()
             if num_prompt_tokens >= self.scheduler_config.max_seq_len:
-                logger.warn(
+                logger.warning(
                     f"Input prompt ({num_prompt_tokens} tokens) is too long"
                     " and exceeds limit of "
                     f"{self.scheduler_config.max_seq_len}")
@@ -206,17 +206,19 @@ class Scheduler:
                 break
 
             # If the number of batched tokens exceeds the limit, stop.
-            if (num_batched_tokens + num_prompt_tokens
-                > self.scheduler_config.max_num_batched_tokens):
+            if (num_batched_tokens + num_prompt_tokens >
+                    self.scheduler_config.max_num_batched_tokens):
                 break
 
             # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
+            num_new_seqs = seq_group.num_seqs(
+                status=SequenceStatus.WAITING)
             num_curr_seqs = sum(
                 seq_group.num_seqs(status=SequenceStatus.RUNNING)
                 for seq_group in self.running)
-            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
+            if (num_curr_seqs + num_new_seqs >
+                    self.scheduler_config.max_num_seqs):
                 break
 
             seq_group = self.waiting.pop(0)
@@ -240,12 +242,11 @@ class Scheduler:
         elapsed_time = now - self.last_logging_time
         if elapsed_time > _LOGGING_INTERVAL_SEC:
             self.last_logging_time = now
-            self.num_input_tokens = [
-                (t, n) for t, n in self.num_input_tokens
-                if now - t < _LOGGING_INTERVAL_SEC
-            ]
+            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
+                                     if now - t < _LOGGING_INTERVAL_SEC]
             if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
+                total_num_tokens = sum(
+                    n for _, n in self.num_input_tokens[:-1])
                 window = now - self.num_input_tokens[0][0]
                 avg_throughput = total_num_tokens / window
             else:
@@ -258,26 +259,30 @@ class Scheduler:
 
             total_num_cpu_blocks = self.cache_config.num_cpu_blocks
             if total_num_cpu_blocks > 0:
-                num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
+                num_free_cpu_blocks = (
+                    self.block_manager.get_num_free_cpu_blocks())
                 num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
                 cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
             else:
                 cpu_cache_usage = 0.0
 
-            logger.info(
-                f"Throughput: {avg_throughput:.1f} tokens/s, "
-                f"Running: {len(self.running)} reqs, "
-                f"Swapped: {len(self.swapped)} reqs, "
-                f"Pending: {len(self.waiting)} reqs, "
-                f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-                f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
+            logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
+                        f"Running: {len(self.running)} reqs, "
+                        f"Swapped: {len(self.swapped)} reqs, "
+                        f"Pending: {len(self.waiting)} reqs, "
+                        f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
+                        f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
 
         return scheduler_outputs, prompt_group_ids, ignored_seq_groups
 
-    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
+    def schedule(
+        self
+    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
+               List[SequenceGroup]]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()
+        (scheduler_outputs, prompt_group_ids,
+         ignored_seq_groups) = self._schedule()
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
@@ -311,8 +316,8 @@ class Scheduler:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
                 if seq.seq_id != output.parent_seq_id:
-                    # The sequence is a fork of the parent sequence (beam search).
-                    # Free the current sequence.
+                    # The sequence is a fork of the parent sequence (beam
+                    # search). Free the current sequence.
                     self.block_manager.free(seq)
                     # Fork the parent sequence.
                     parent_seq = seq_group.find(output.parent_seq_id)
@@ -385,7 +390,7 @@ class Scheduler:
         elif preemption_mode == PreemptionMode.SWAP:
             self._preempt_by_swap(seq_group, blocks_to_swap_out)
         else:
-            assert False, 'Invalid preemption mode.'
+            assert False, "Invalid preemption mode."
 
     def _preempt_by_recompute(
         self,
...
@@ -12,11 +12,11 @@ class EngineArgs:
    """Arguments for vLLM engine."""
    model: str
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    download_dir: Optional[str] = None
    use_np_weights: bool = False
    use_dummy_weights: bool = False
    dtype: str = 'auto'
    seed: int = 0
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
@@ -35,76 +35,101 @@ class EngineArgs:
    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine."""
        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument('--use-np-weights',
                            action='store_true',
                            help='save a numpy copy of model weights for '
                            'faster loading. This can increase the disk '
                            'usage by up to 2x.')
        parser.add_argument('--use-dummy-weights',
                            action='store_true',
                            help='use dummy values for model weights')
        # TODO(woosuk): Support FP32.
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=['auto', 'half', 'bfloat16', 'float'],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument('--gpu-memory-utilization',
                            type=float,
                            default=EngineArgs.gpu_memory_utilization,
                            help='the percentage of GPU memory to be used for'
                            'the model executor')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        return parser
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
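Together, `add_cli_args` and `from_cli_args` form the usual two-step path from the command line to an `EngineArgs` instance. A minimal usage sketch, assuming an import path that this diff does not show:

import argparse

from vllm.engine.arg_utils import EngineArgs  # import path is an assumption

parser = argparse.ArgumentParser(description='vLLM demo')
parser = EngineArgs.add_cli_args(parser)
# Values here are illustrative only.
args = parser.parse_args(['--model', 'facebook/opt-125m', '--max-num-seqs', '64'])
engine_args = EngineArgs.from_cli_args(args)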
@@ -115,18 +140,19 @@ class EngineArgs:
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
        # Initialize the configs.
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.download_dir,
                                   self.use_np_weights, self.use_dummy_weights,
                                   self.dtype, self.seed)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
        model_max_len = getattr(model_config.hf_config,
                                'max_position_embeddings', float('inf'))
        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs, max_seq_len)
        return model_config, cache_config, parallel_config, scheduler_config
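One behavior worth calling out in `create_engine_configs`: the scheduler's maximum sequence length is the smaller of the batching budget and the model's positional limit, and models whose config lacks `max_position_embeddings` fall back to `float('inf')`, so only the batching budget applies. A small self-contained restatement with made-up numbers:

class _FakeHFConfig:
    # Stand-in for model_config.hf_config; value is made up.
    max_position_embeddings = 2048

max_num_batched_tokens = 2560
model_max_len = getattr(_FakeHFConfig, 'max_position_embeddings', float('inf'))
max_seq_len = min(max_num_batched_tokens, model_max_len)
assert max_seq_len == 2048  # capped by the model, not the batching budget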
@@ -140,12 +166,13 @@ class AsyncEngineArgs(EngineArgs):
    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='disable logging requests')
        return parser
@@ -11,7 +11,7 @@ from vllm.sampling_params import SamplingParams
logger = init_logger(__name__)
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
class AsyncLLMEngine:
@@ -35,8 +35,13 @@ class AsyncLLMEngine:
        log_requests: Whether to log the requests.
        *args, **kwargs: Arguments for LLMEngine.
    """
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
@@ -76,12 +81,11 @@ class AsyncLLMEngine:
        self.request_events[request_id].set()
    async def generate(
            self,
            prompt: Optional[str],
            sampling_params: SamplingParams,
            request_id: str,
            prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
        """Generate outputs for a request.
        Generate outputs for a request. This method is a coroutine. It adds the
@@ -117,14 +121,17 @@ class AsyncLLMEngine:
        # Add the request into the vLLM engine's waiting queue.
        if self.engine_use_ray:
            await self.engine.add_request.remote(
                request_id,
                prompt,
                sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)
        else:
            self.engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids=prompt_token_ids,
                                    arrival_time=arrival_time)
        # The vLLM engine does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
@@ -200,7 +207,8 @@ class AsyncLLMEngine:
        self.kicking_request_id = None
    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
@@ -211,8 +219,9 @@ class AsyncLLMEngine:
        # Create the async LLM engine.
        engine = cls(engine_args.worker_use_ray,
                     engine_args.engine_use_ray,
                     *engine_configs,
                     distributed_init_method,
                     devices,
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats)
        return engine
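The classmethod above is the intended entry point for servers: parse `AsyncEngineArgs` from the CLI, then pass them to `from_engine_args`, which wires in the engine configs and the logging settings. A minimal sketch, with the import paths assumed and the flags chosen only for illustration (constructing the engine for real requires a working model and GPU setup):

import argparse

from vllm.engine.arg_utils import AsyncEngineArgs          # import paths are
from vllm.engine.async_llm_engine import AsyncLLMEngine    # assumptions

parser = argparse.ArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(['--model', 'facebook/opt-125m',
                          '--disable-log-requests'])
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)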