Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
16422ea7
Unverified
Commit
16422ea7
authored
Aug 13, 2024
by
youkaichao
Committed by
GitHub
Aug 13, 2024
Browse files
[misc][plugin] add plugin system implementation (#7426)
parent
373538f9
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
161 additions
and
101 deletions
+161
-101
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+4
-0
requirements-common.txt
requirements-common.txt
+1
-0
tests/conftest.py
tests/conftest.py
+26
-0
tests/distributed/test_distributed_oot.py
tests/distributed/test_distributed_oot.py
+6
-0
tests/entrypoints/openai/test_oot_registration.py
tests/entrypoints/openai/test_oot_registration.py
+27
-79
tests/models/test_oot_registration.py
tests/models/test_oot_registration.py
+15
-20
tests/plugins/vllm_add_dummy_model/setup.py
tests/plugins/vllm_add_dummy_model/setup.py
+9
-0
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
...ins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
+26
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+3
-0
vllm/envs.py
vllm/envs.py
+9
-1
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-1
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+31
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+3
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
16422ea7
...
...
@@ -77,11 +77,13 @@ steps:
-
pytest -v -s core
-
label
:
Entrypoints Test
# 20min
working_dir
:
"
/vllm-workspace/tests"
fast_check
:
true
mirror_hardwares
:
[
amd
]
source_file_dependencies
:
-
vllm/
commands
:
-
pip install -e ./plugins/vllm_add_dummy_model
-
pytest -v -s entrypoints/llm
-
pytest -v -s entrypoints/openai
...
...
@@ -154,6 +156,7 @@ steps:
-
vllm/
-
tests/models
commands
:
-
pip install -e ./plugins/vllm_add_dummy_model
-
pytest -v -s models -m \"not vlm\"
-
label
:
Vision Language Models Test
# 42min
...
...
@@ -289,6 +292,7 @@ steps:
-
pytest -v -s distributed/test_chunked_prefill_distributed.py
-
pytest -v -s distributed/test_multimodal_broadcast.py
-
pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
-
pytest -v -s distributed/test_distributed_oot.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
...
...
requirements-common.txt
View file @
16422ea7
...
...
@@ -23,4 +23,5 @@ pyzmq
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
compressed-tensors == 0.5.0
tests/conftest.py
View file @
16422ea7
import
contextlib
import
gc
import
json
import
os
import
sys
import
tempfile
from
collections
import
UserList
from
enum
import
Enum
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
,
...
...
@@ -11,6 +13,7 @@ import pytest
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
huggingface_hub
import
snapshot_download
from
PIL
import
Image
from
transformers
import
(
AutoModelForCausalLM
,
AutoModelForSeq2SeqLM
,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
,
...
...
@@ -757,3 +760,26 @@ def num_gpus_available():
in current process."""
return
cuda_device_count_stateless
()
# Location of the patched OPT-125m metadata. It sits in the system temp
# directory so repeated test runs can reuse an earlier download.
temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")


@pytest.fixture
def dummy_opt_path():
    """Return a local model directory whose config names the dummy architecture.

    The directory is populated from ``facebook/opt-125m`` (weight files are
    excluded via ``ignore_patterns``) and its ``config.json`` is patched so
    that ``architectures`` is ``["MyOPTForCausalLM"]``.

    Returns:
        The path of the prepared directory (``_dummy_path``).
    """
    json_path = os.path.join(_dummy_path, "config.json")
    # Key the download on config.json rather than on the directory itself:
    # an interrupted earlier run can leave the directory behind without the
    # file, which would otherwise trip the assertion below on every run.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
    assert os.path.exists(json_path)
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Only rewrite the file when it is not patched yet, so an already
    # prepared directory is not truncated and rewritten on every test.
    if config.get("architectures") != ["MyOPTForCausalLM"]:
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_path
tests/distributed/test_distributed_oot.py
0 → 100644
View file @
16422ea7
from
..entrypoints.openai.test_oot_registration
import
(
run_and_test_dummy_opt_api_server
)
def test_distributed_oot(dummy_opt_path: str):
    """Run the shared dummy-OPT API server check with ``tp=2``."""
    tensor_parallel_size = 2
    run_and_test_dummy_opt_api_server(dummy_opt_path,
                                      tp=tensor_parallel_size)
tests/entrypoints/openai/test_oot_registration.py
View file @
16422ea7
import
sys
import
time
import
torch
from
openai
import
OpenAI
,
OpenAIError
from
vllm
import
ModelRegistry
from
vllm.model_executor.models.opt
import
OPTForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.utils
import
get_open_port
from
...utils
import
VLLM_PATH
,
RemoteOpenAIServer
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
assert
chatml_jinja_path
.
exists
()
class MyOPTForCausalLM(OPTForCausalLM):
    """OPT subclass used as a dummy model: its logits always favour index 0."""

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the parent's logits overwritten with a one-hot on column 0."""
        forced = super().compute_logits(hidden_states, sampling_metadata)
        # Wipe the real scores in place, then raise only the first column so
        # this dummy model always predicts the first token.
        forced.zero_()
        forced[:, 0] += 1.0
        return forced
def
server_function
(
port
:
int
):
# register our dummy model
ModelRegistry
.
register_model
(
"OPTForCausalLM"
,
MyOPTForCausalLM
)
sys
.
argv
=
[
"placeholder.py"
]
+
[
"--model"
,
"facebook/opt-125m"
,
def
run_and_test_dummy_opt_api_server
(
model
,
tp
=
1
):
# the model is registered through the plugin
server_args
=
[
"--gpu-memory-utilization"
,
"0.10"
,
"--dtype"
,
"float32"
,
"--api-key"
,
"token-abc123"
,
"--port"
,
str
(
port
),
"--chat-template"
,
str
(
chatml_jinja_path
),
"--load-format"
,
"dummy"
,
"-tp"
,
f
"
{
tp
}
"
,
]
import
runpy
runpy
.
run_module
(
'vllm.entrypoints.openai.api_server'
,
run_name
=
'__main__'
)
def
test_oot_registration_for_api_server
():
port
=
get_open_port
()
ctx
=
torch
.
multiprocessing
.
get_context
()
server
=
ctx
.
Process
(
target
=
server_function
,
args
=
(
port
,
))
server
.
start
()
try
:
client
=
OpenAI
(
base_url
=
f
"http://localhost:
{
port
}
/v1"
,
api_key
=
"token-abc123"
,
)
now
=
time
.
time
()
while
True
:
try
:
with
RemoteOpenAIServer
(
model
,
server_args
)
as
server
:
client
=
server
.
get_client
()
completion
=
client
.
chat
.
completions
.
create
(
model
=
"facebook/opt-125m"
,
model
=
model
,
messages
=
[{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant."
...
...
@@ -74,21 +31,12 @@ def test_oot_registration_for_api_server():
}],
temperature
=
0
,
)
break
except
OpenAIError
as
e
:
if
"Connection error"
in
str
(
e
):
time
.
sleep
(
3
)
if
time
.
time
()
-
now
>
RemoteOpenAIServer
.
MAX_START_WAIT_S
:
msg
=
"Server did not start in time"
raise
RuntimeError
(
msg
)
from
e
else
:
raise
e
finally
:
server
.
terminate
()
generated_text
=
completion
.
choices
[
0
].
message
.
content
assert
generated_text
is
not
None
# make sure only the first token is generated
# TODO(youkaichao): Fix the test with plugin
rest
=
generated_text
.
replace
(
"<s>"
,
""
)
# noqa
# assert rest == ""
rest
=
generated_text
.
replace
(
"<s>"
,
""
)
assert
rest
==
""
def
test_oot_registration_for_api_server
(
dummy_opt_path
:
str
):
run_and_test_dummy_opt_api_server
(
dummy_opt_path
)
tests/models/test_oot_registration.py
View file @
16422ea7
from
typing
import
Optional
import
os
import
torch
import
pytest
from
vllm
import
LLM
,
ModelRegistry
,
SamplingParams
from
vllm.model_executor.models.opt
import
OPTForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm
import
LLM
,
SamplingParams
# NOTE: the order of the tests is important
# the first test does not load any plugins
# the second test loads the plugin
# they share the same process, so the plugin is loaded for the second test
class MyOPTForCausalLM(OPTForCausalLM):
    """OPT subclass used as a dummy model: its logits always favour index 0."""

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the parent's logits overwritten with a one-hot on column 0.

        The declared return type allows ``None``; in that case there is
        nothing to overwrite and ``None`` is passed through unchanged
        instead of raising ``AttributeError`` on ``None.zero_()``.
        """
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits
def test_plugin(dummy_opt_path):
    """Building an ``LLM`` for the dummy model must fail with no plugins."""
    # VLLM_PLUGINS is set to the empty string before the engine is built.
    os.environ["VLLM_PLUGINS"] = ""
    with pytest.raises(Exception) as excinfo:
        LLM(model=dummy_opt_path, load_format="dummy")
    error_text = str(excinfo.value)
    assert "are not supported for now" in error_text
def
test_oot_registration
():
# register our dummy model
ModelRegistry
.
register_model
(
"OPTForCausalLM"
,
MyOPTForCausalLM
)
def
test_oot_registration
(
dummy_opt_path
):
os
.
environ
[
"VLLM_PLUGINS"
]
=
"register_dummy_model"
prompts
=
[
"Hello, my name is"
,
"The text does not matter"
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
llm
=
LLM
(
model
=
"facebook/opt-125m
"
)
llm
=
LLM
(
model
=
dummy_opt_path
,
load_format
=
"dummy
"
)
first_token
=
llm
.
get_tokenizer
().
decode
(
0
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
...
...
tests/plugins/vllm_add_dummy_model/setup.py
0 → 100644
View file @
16422ea7
from
setuptools
import
setup
# Entry-point group -> list of "name = module:callable" specs. The plugin is
# advertised under the 'vllm.general_plugins' group.
_ENTRY_POINTS = {
    'vllm.general_plugins':
    ["register_dummy_model = vllm_add_dummy_model:register"],
}

setup(
    name='vllm_add_dummy_model',
    version='0.1',
    packages=['vllm_add_dummy_model'],
    entry_points=_ENTRY_POINTS,
)
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
0 → 100644
View file @
16422ea7
from
typing
import
Optional
import
torch
from
vllm
import
ModelRegistry
from
vllm.model_executor.models.opt
import
OPTForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
class MyOPTForCausalLM(OPTForCausalLM):
    """OPT subclass used as a dummy model: its logits always favour index 0."""

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the parent's logits overwritten with a one-hot on column 0.

        When the parent returns ``None`` it is passed through untouched.
        """
        base_logits = super().compute_logits(hidden_states, sampling_metadata)
        if base_logits is None:
            return None
        # Wipe the real scores in place, then raise only the first column so
        # this dummy model always predicts the first token.
        base_logits.zero_()
        base_logits[:, 0] += 1.0
        return base_logits
def register():
    """Register ``MyOPTForCausalLM`` with the model registry if it is absent.

    The membership check makes repeated calls harmless: once the name is
    reported by ``ModelRegistry.get_supported_archs()`` nothing is done.
    """
    already_known = "MyOPTForCausalLM" in ModelRegistry.get_supported_archs()
    if already_known:
        return
    # register our dummy model
    ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
vllm/engine/llm_engine.py
View file @
16422ea7
...
...
@@ -227,6 +227,9 @@ class LLMEngine:
)
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
...
...
vllm/envs.py
View file @
16422ea7
import
os
import
tempfile
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
if
TYPE_CHECKING
:
VLLM_HOST_IP
:
str
=
""
...
...
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
VERBOSE
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
def
get_default_cache_root
():
...
...
@@ -362,6 +363,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_TEST_FORCE_FP8_MARLIN"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded
"VLLM_PLUGINS"
:
lambda
:
None
if
"VLLM_PLUGINS"
not
in
os
.
environ
else
os
.
environ
[
"VLLM_PLUGINS"
].
split
(
","
),
}
# end-env-vars-definition
...
...
vllm/model_executor/models/__init__.py
View file @
16422ea7
...
...
@@ -166,7 +166,7 @@ class ModelRegistry:
@staticmethod
def get_supported_archs() -> List[str]:
    """Return the names of all architectures the registry knows about.

    The result lists the keys of ``_MODELS`` followed by the keys of
    ``_OOT_MODELS``, so callers checking membership also see the latter.
    """
    # A single return: the earlier bare `_MODELS`-only return made the
    # combined result unreachable.
    return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
@
staticmethod
def
register_model
(
model_arch
:
str
,
model_cls
:
Type
[
nn
.
Module
]):
...
...
vllm/plugins/__init__.py
0 → 100644
View file @
16422ea7
import
logging
import
vllm.envs
as
envs
logger
=
logging
.
getLogger
(
__name__
)
def load_general_plugins():
    """Discover and run the enabled ``vllm.general_plugins`` entry points.

    WARNING: plugins can be loaded multiple times in different processes.
    They should be designed so that being loaded repeatedly causes no issues.

    ``envs.VLLM_PLUGINS`` selects what runs: ``None`` enables every
    discovered plugin, otherwise only plugins whose name is in the list are
    loaded. A plugin that raises while loading or running is logged with its
    traceback and skipped; the remaining plugins are still processed.
    """
    import sys
    if sys.version_info < (3, 10):
        from importlib_metadata import entry_points
    else:
        from importlib.metadata import entry_points

    allowed_plugins = envs.VLLM_PLUGINS
    if allowed_plugins is not None:
        # Tolerate whitespace around names (e.g. "a, b") and drop empty
        # items, so a padded name is not silently ignored. An empty value
        # still yields an empty allow-list, i.e. no plugin is loaded.
        allowed_plugins = {
            name.strip()
            for name in allowed_plugins if name.strip()
        }

    discovered_plugins = entry_points(group='vllm.general_plugins')
    for plugin in discovered_plugins:
        logger.info("Found general plugin: %s", plugin.name)
        if allowed_plugins is not None and plugin.name not in allowed_plugins:
            continue
        try:
            func = plugin.load()
            func()
            logger.info("Loaded general plugin: %s", plugin.name)
        except Exception:
            # Deliberately best-effort: one broken plugin must not prevent
            # the others from loading.
            logger.exception("Failed to load general plugin: %s", plugin.name)
vllm/worker/worker_base.py
View file @
16422ea7
...
...
@@ -411,6 +411,9 @@ class WorkerWrapperBase:
# see https://github.com/NVIDIA/nccl/issues/1234
os
.
environ
[
'NCCL_CUMEM_ENABLE'
]
=
'0'
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
if
self
.
worker_class_fn
:
worker_class
=
self
.
worker_class_fn
()
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment