Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
64172a97
Unverified
Commit
64172a97
authored
Mar 25, 2024
by
xwjiang2010
Committed by
GitHub
Mar 25, 2024
Browse files
[Feature] Add vision language model support. (#3042)
parent
f408d05c
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
529 additions
and
63 deletions
+529
-63
.buildkite/download-images.sh
.buildkite/download-images.sh
+18
-0
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+7
-1
examples/llava_example.py
examples/llava_example.py
+84
-0
requirements-dev.txt
requirements-dev.txt
+4
-0
tests/conftest.py
tests/conftest.py
+112
-14
tests/models/test_llava.py
tests/models/test_llava.py
+110
-0
tests/models/test_models.py
tests/models/test_models.py
+1
-1
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+1
-1
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+1
-1
tests/worker/test_swap.py
tests/worker/test_swap.py
+1
-1
vllm/config.py
vllm/config.py
+62
-17
vllm/core/scheduler.py
vllm/core/scheduler.py
+6
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+53
-3
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+16
-9
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+11
-6
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+21
-5
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+2
-1
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+4
-1
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+4
-1
vllm/model_executor/model_loader.py
vllm/model_executor/model_loader.py
+11
-1
No files found.
.buildkite/download-images.sh
0 → 100644
View file @
64172a97
#!/bin/bash
set
-ex
set
-o
pipefail
(
which wget
&&
which curl
)
||
(
apt-get update
&&
apt-get
install
-y
wget curl
)
# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
mkdir
-p
images
cd
images
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg
cd
-
.buildkite/test-pipeline.yaml
View file @
64172a97
...
@@ -39,9 +39,15 @@ steps:
...
@@ -39,9 +39,15 @@ steps:
-
label
:
Models Test
-
label
:
Models Test
commands
:
commands
:
-
pytest -v -s models --forked
-
bash ../.buildkite/download-images.sh
-
pytest -v -s models --ignore=models/test_llava.py --forked
soft_fail
:
true
soft_fail
:
true
-
label
:
Llava Test
commands
:
-
bash ../.buildkite/download-images.sh
-
pytest -v -s models/test_llava.py
-
label
:
Prefix Caching Test
-
label
:
Prefix Caching Test
commands
:
commands
:
-
pytest -v -s prefix_caching
-
pytest -v -s prefix_caching
...
...
examples/llava_example.py
0 → 100644
View file @
64172a97
import
argparse
import
os
import
subprocess
import
torch
from
vllm
import
LLM
from
vllm.sequence
import
MultiModalData
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
def
run_llava_pixel_values
():
llm
=
LLM
(
model
=
"llava-hf/llava-1.5-7b-hf"
,
image_input_type
=
"pixel_values"
,
image_token_id
=
32000
,
image_input_shape
=
"1,3,336,336"
,
image_feature_size
=
576
,
)
prompt
=
"<image>"
*
576
+
(
"
\n
USER: What is the content of this image?
\n
ASSISTANT:"
)
# This should be provided by another online or offline component.
images
=
torch
.
load
(
"images/stop_sign_pixel_values.pt"
)
outputs
=
llm
.
generate
(
prompt
,
multi_modal_data
=
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
images
))
for
o
in
outputs
:
generated_text
=
o
.
outputs
[
0
].
text
print
(
generated_text
)
def
run_llava_image_features
():
llm
=
LLM
(
model
=
"llava-hf/llava-1.5-7b-hf"
,
image_input_type
=
"image_features"
,
image_token_id
=
32000
,
image_input_shape
=
"1,576,1024"
,
image_feature_size
=
576
,
)
prompt
=
"<image>"
*
576
+
(
"
\n
USER: What is the content of this image?
\n
ASSISTANT:"
)
# This should be provided by another online or offline component.
images
=
torch
.
load
(
"images/stop_sign_image_features.pt"
)
outputs
=
llm
.
generate
(
prompt
,
multi_modal_data
=
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
images
))
for
o
in
outputs
:
generated_text
=
o
.
outputs
[
0
].
text
print
(
generated_text
)
def
main
(
args
):
if
args
.
type
==
"pixel_values"
:
run_llava_pixel_values
()
else
:
run_llava_image_features
()
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Demo on Llava"
)
parser
.
add_argument
(
"--type"
,
type
=
str
,
choices
=
[
"pixel_values"
,
"image_features"
],
default
=
"pixel_values"
,
help
=
"image input type"
)
args
=
parser
.
parse_args
()
# Download from s3
s3_bucket_path
=
"s3://air-example-data-2/vllm_opensource_llava/"
local_directory
=
"images"
# Make sure the local directory exists or create it
os
.
makedirs
(
local_directory
,
exist_ok
=
True
)
# Use AWS CLI to sync the directory
subprocess
.
check_call
(
[
"aws"
,
"s3"
,
"sync"
,
s3_bucket_path
,
local_directory
])
main
(
args
)
requirements-dev.txt
View file @
64172a97
...
@@ -24,6 +24,10 @@ openai
...
@@ -24,6 +24,10 @@ openai
requests
requests
ray
ray
peft
peft
awscli
# Benchmarking
# Benchmarking
aiohttp
aiohttp
# Multimodal
pillow
tests/conftest.py
View file @
64172a97
...
@@ -3,16 +3,39 @@ from typing import List, Optional, Tuple
...
@@ -3,16 +3,39 @@ from typing import List, Optional, Tuple
import
pytest
import
pytest
import
torch
import
torch
from
transformers
import
AutoModelForCausalLM
from
PIL
import
Image
from
transformers
import
(
AutoModelForCausalLM
,
AutoProcessor
,
LlavaForConditionalGeneration
)
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
,
VisionLanguageConfig
from
vllm.sequence
import
MultiModalData
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
_TEST_DIR
=
os
.
path
.
dirname
(
__file__
)
_TEST_DIR
=
os
.
path
.
dirname
(
__file__
)
_TEST_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"example.txt"
)]
_TEST_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"example.txt"
)]
_LONG_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"summary.txt"
)]
_LONG_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"summary.txt"
)]
# Multi modal related
_PIXEL_VALUES_FILES
=
[
os
.
path
.
join
(
_TEST_DIR
,
"images"
,
filename
)
for
filename
in
[
"stop_sign_pixel_values.pt"
,
"cherry_blossom_pixel_values.pt"
]
]
_IMAGE_FEATURES_FILES
=
[
os
.
path
.
join
(
_TEST_DIR
,
"images"
,
filename
)
for
filename
in
[
"stop_sign_image_features.pt"
,
"cherry_blossom_image_features.pt"
]
]
_IMAGE_FILES
=
[
os
.
path
.
join
(
_TEST_DIR
,
"images"
,
filename
)
for
filename
in
[
"stop_sign.jpg"
,
"cherry_blossom.jpg"
]
]
_IMAGE_PROMPTS
=
[
"<image>
\n
USER: What's the content of the image?
\n
ASSISTANT:"
,
"<image>
\n
USER: What is the season?
\n
ASSISTANT:"
]
assert
len
(
_PIXEL_VALUES_FILES
)
==
len
(
_IMAGE_FEATURES_FILES
)
==
len
(
_IMAGE_FILES
)
==
len
(
_IMAGE_PROMPTS
)
def
_read_prompts
(
filename
:
str
)
->
List
[
str
]:
def
_read_prompts
(
filename
:
str
)
->
List
[
str
]:
with
open
(
filename
,
"r"
)
as
f
:
with
open
(
filename
,
"r"
)
as
f
:
...
@@ -20,6 +43,39 @@ def _read_prompts(filename: str) -> List[str]:
...
@@ -20,6 +43,39 @@ def _read_prompts(filename: str) -> List[str]:
return
prompts
return
prompts
@
pytest
.
fixture
(
scope
=
"session"
)
def
hf_image_prompts
()
->
List
[
str
]:
return
_IMAGE_PROMPTS
@
pytest
.
fixture
(
scope
=
"session"
)
def
hf_images
()
->
List
[
Image
.
Image
]:
return
[
Image
.
open
(
filename
)
for
filename
in
_IMAGE_FILES
]
@
pytest
.
fixture
()
def
vllm_images
(
request
)
->
"torch.Tensor"
:
vision_language_config
=
request
.
getfixturevalue
(
"model_and_config"
)[
1
]
all_images
=
[]
if
vision_language_config
.
image_input_type
==
(
VisionLanguageConfig
.
ImageInputType
.
IMAGE_FEATURES
):
filenames
=
_IMAGE_FEATURES_FILES
else
:
filenames
=
_PIXEL_VALUES_FILES
for
filename
in
filenames
:
all_images
.
append
(
torch
.
load
(
filename
))
return
torch
.
concat
(
all_images
,
dim
=
0
)
@
pytest
.
fixture
()
def
vllm_image_prompts
(
request
)
->
List
[
str
]:
vision_language_config
=
request
.
getfixturevalue
(
"model_and_config"
)[
1
]
return
[
"<image>"
*
(
vision_language_config
.
image_feature_size
-
1
)
+
p
for
p
in
_IMAGE_PROMPTS
]
@
pytest
.
fixture
@
pytest
.
fixture
def
example_prompts
()
->
List
[
str
]:
def
example_prompts
()
->
List
[
str
]:
prompts
=
[]
prompts
=
[]
...
@@ -42,6 +98,10 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
...
@@ -42,6 +98,10 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"float"
:
torch
.
float
,
"float"
:
torch
.
float
,
}
}
_VISION_LANGUAGE_MODELS
=
{
"llava-hf/llava-1.5-7b-hf"
:
LlavaForConditionalGeneration
,
}
class
HfRunner
:
class
HfRunner
:
...
@@ -53,11 +113,24 @@ class HfRunner:
...
@@ -53,11 +113,24 @@ class HfRunner:
)
->
None
:
)
->
None
:
assert
dtype
in
_STR_DTYPE_TO_TORCH_DTYPE
assert
dtype
in
_STR_DTYPE_TO_TORCH_DTYPE
torch_dtype
=
_STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
torch_dtype
=
_STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
self
.
model_name
=
model_name
model_name
,
if
model_name
not
in
_VISION_LANGUAGE_MODELS
:
torch_dtype
=
torch_dtype
,
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
trust_remote_code
=
True
,
model_name
,
).
cuda
()
torch_dtype
=
torch_dtype
,
trust_remote_code
=
True
,
).
cuda
()
self
.
processor
=
None
else
:
self
.
model
=
_VISION_LANGUAGE_MODELS
[
model_name
].
from_pretrained
(
model_name
,
torch_dtype
=
torch_dtype
,
trust_remote_code
=
True
,
).
cuda
()
self
.
processor
=
AutoProcessor
.
from_pretrained
(
model_name
,
torch_dtype
=
torch_dtype
,
)
if
tokenizer_name
is
None
:
if
tokenizer_name
is
None
:
tokenizer_name
=
model_name
tokenizer_name
=
model_name
self
.
tokenizer
=
get_tokenizer
(
tokenizer_name
,
trust_remote_code
=
True
)
self
.
tokenizer
=
get_tokenizer
(
tokenizer_name
,
trust_remote_code
=
True
)
...
@@ -65,13 +138,28 @@ class HfRunner:
...
@@ -65,13 +138,28 @@ class HfRunner:
def
generate
(
def
generate
(
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
outputs
:
List
[
Tuple
[
List
[
int
],
str
]]
=
[]
outputs
:
List
[
Tuple
[
List
[
int
],
str
]]
=
[]
for
prompt
in
prompts
:
if
images
:
input_ids
=
self
.
tokenizer
(
prompt
,
return_tensors
=
"pt"
).
input_ids
assert
len
(
prompts
)
==
len
(
images
)
for
i
,
prompt
in
enumerate
(
prompts
):
if
self
.
model_name
not
in
_VISION_LANGUAGE_MODELS
:
input_ids
=
self
.
tokenizer
(
prompt
,
return_tensors
=
"pt"
).
input_ids
inputs
=
{
"input_ids"
:
input_ids
.
cuda
()}
else
:
image
=
images
[
i
]
if
images
else
None
inputs
=
self
.
processor
(
text
=
prompt
,
images
=
image
,
return_tensors
=
"pt"
)
inputs
=
{
key
:
value
.
cuda
()
if
value
is
not
None
else
None
for
key
,
value
in
inputs
.
items
()
}
output_ids
=
self
.
model
.
generate
(
output_ids
=
self
.
model
.
generate
(
input
_ids
.
cuda
()
,
**
input
s
,
use_cache
=
True
,
use_cache
=
True
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -88,10 +176,12 @@ class HfRunner:
...
@@ -88,10 +176,12 @@ class HfRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
images
:
Optional
[
"torch.Tensor"
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
outputs
=
self
.
generate
(
prompts
,
outputs
=
self
.
generate
(
prompts
,
do_sample
=
False
,
do_sample
=
False
,
max_new_tokens
=
max_tokens
)
max_new_tokens
=
max_tokens
,
images
=
images
)
for
i
in
range
(
len
(
outputs
)):
for
i
in
range
(
len
(
outputs
)):
output_ids
,
output_str
=
outputs
[
i
]
output_ids
,
output_str
=
outputs
[
i
]
outputs
[
i
]
=
(
output_ids
[
0
],
output_str
[
0
])
outputs
[
i
]
=
(
output_ids
[
0
],
output_str
[
0
])
...
@@ -183,9 +273,16 @@ class VllmRunner:
...
@@ -183,9 +273,16 @@ class VllmRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
images
:
Optional
[
"torch.Tensor"
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
req_outputs
=
self
.
model
.
generate
(
prompts
,
if
images
is
not
None
:
sampling_params
=
sampling_params
)
assert
len
(
prompts
)
==
images
.
shape
[
0
]
req_outputs
=
self
.
model
.
generate
(
prompts
,
sampling_params
=
sampling_params
,
multi_modal_data
=
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
images
)
if
images
is
not
None
else
None
)
outputs
=
[]
outputs
=
[]
for
req_output
in
req_outputs
:
for
req_output
in
req_outputs
:
prompt_str
=
req_output
.
prompt
prompt_str
=
req_output
.
prompt
...
@@ -222,9 +319,10 @@ class VllmRunner:
...
@@ -222,9 +319,10 @@ class VllmRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
images
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
greedy_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
greedy_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
outputs
=
self
.
generate
(
prompts
,
greedy_params
)
outputs
=
self
.
generate
(
prompts
,
greedy_params
,
images
=
images
)
return
[(
output_ids
[
0
],
output_str
[
0
])
return
[(
output_ids
[
0
],
output_str
[
0
])
for
output_ids
,
output_str
in
outputs
]
for
output_ids
,
output_str
in
outputs
]
...
...
tests/models/test_llava.py
0 → 100644
View file @
64172a97
import
gc
from
dataclasses
import
fields
from
enum
import
Enum
from
typing
import
Dict
,
List
,
Tuple
import
pytest
import
torch
from
transformers
import
AutoTokenizer
from
vllm.config
import
VisionLanguageConfig
model_and_vl_config
=
[
(
"llava-hf/llava-1.5-7b-hf"
,
VisionLanguageConfig
(
image_input_type
=
VisionLanguageConfig
.
ImageInputType
.
PIXEL_VALUES
,
image_feature_size
=
576
,
image_token_id
=
32000
,
image_input_shape
=
(
1
,
3
,
336
,
336
))),
(
"llava-hf/llava-1.5-7b-hf"
,
VisionLanguageConfig
(
image_input_type
=
VisionLanguageConfig
.
ImageInputType
.
IMAGE_FEATURES
,
image_feature_size
=
576
,
image_token_id
=
32000
,
image_input_shape
=
(
1
,
576
,
1024
)))
]
def
as_dict
(
vision_language_config
:
VisionLanguageConfig
)
->
Dict
:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects.
"""
result
=
{}
for
field
in
fields
(
vision_language_config
):
value
=
getattr
(
vision_language_config
,
field
.
name
)
if
isinstance
(
value
,
Enum
):
result
[
field
.
name
]
=
value
.
name
.
lower
()
elif
isinstance
(
value
,
tuple
):
result
[
field
.
name
]
=
","
.
join
([
str
(
item
)
for
item
in
value
])
else
:
result
[
field
.
name
]
=
value
return
result
def
sanitize_vllm_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
],
vision_language_config
:
VisionLanguageConfig
,
model_id
:
str
):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_id
)
image_token_str
=
tokenizer
.
decode
(
vision_language_config
.
image_token_id
)
image_token_str_len
=
len
(
image_token_str
)
input_ids
,
output_str
=
vllm_output
sanitized_input_ids
=
input_ids
[
0
:
2
]
+
input_ids
[
2
+
vision_language_config
.
image_feature_size
-
1
:]
sanitzied_output_str
=
output_str
[
vision_language_config
.
image_feature_size
*
image_token_str_len
:]
return
sanitized_input_ids
,
sanitzied_output_str
@
pytest
.
mark
.
parametrize
(
"worker_use_ray"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"model_and_config"
,
model_and_vl_config
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
def
test_models
(
hf_runner
,
vllm_runner
,
hf_image_prompts
,
hf_images
,
vllm_image_prompts
,
vllm_images
,
model_and_config
:
tuple
,
dtype
:
str
,
max_tokens
:
int
,
worker_use_ray
:
bool
)
->
None
:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the raw images as input.
For vllm runner, we provide image tensors and corresponding
vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id
,
vision_language_config
=
model_and_config
hf_model
=
hf_runner
(
model_id
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
hf_image_prompts
,
max_tokens
,
images
=
hf_images
)
del
hf_model
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
vllm_model
=
vllm_runner
(
model_id
,
dtype
=
dtype
,
worker_use_ray
=
worker_use_ray
,
**
as_dict
(
vision_language_config
))
vllm_outputs
=
vllm_model
.
generate_greedy
(
vllm_image_prompts
,
max_tokens
,
images
=
vllm_images
)
del
vllm_model
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
for
i
in
range
(
len
(
hf_image_prompts
)):
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
vllm_output_ids
,
vllm_output_str
=
sanitize_vllm_output
(
vllm_outputs
[
i
],
vision_language_config
,
model_id
)
assert
hf_output_str
==
vllm_output_str
,
(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_str
!
r
}
\n
vLLM:
{
vllm_output_str
!
r
}
"
)
assert
hf_output_ids
==
vllm_output_ids
,
(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_ids
}
\n
vLLM:
{
vllm_output_ids
}
"
)
tests/models/test_models.py
View file @
64172a97
...
@@ -25,7 +25,7 @@ MODELS = [
...
@@ -25,7 +25,7 @@ MODELS = [
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"
float
"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"
half
"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
def
test_models
(
def
test_models
(
hf_runner
,
hf_runner
,
...
...
tests/spec_decode/utils.py
View file @
64172a97
...
@@ -109,7 +109,7 @@ def create_worker(cls: type,
...
@@ -109,7 +109,7 @@ def create_worker(cls: type,
)
)
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
_
)
=
engine_args
.
create_engine_configs
()
device_config
,
_
,
_
)
=
engine_args
.
create_engine_configs
()
distributed_init_method
=
get_distributed_init_method
(
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
get_ip
(),
get_open_port
())
...
...
tests/worker/test_model_runner.py
View file @
64172a97
...
@@ -35,7 +35,7 @@ def test_prepare_prompt(batch_size):
...
@@ -35,7 +35,7 @@ def test_prepare_prompt(batch_size):
prompt_len
-
1
)
prompt_len
-
1
)
selected_token_start_idx
+=
prompt_len
selected_token_start_idx
+=
prompt_len
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
(
input_tokens
,
input_positions
,
attn_metadata
,
return_prompt_lens
,
_
,
_
,
_
,
_
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
_
,
_
)
=
(
model_runner
.
_prepare_prompt
(
seq_group_metadata_list
))
assert
return_prompt_lens
==
prompt_lens
assert
return_prompt_lens
==
prompt_lens
# Verify input metadata is correct for prompts.
# Verify input metadata is correct for prompts.
...
...
tests/worker/test_swap.py
View file @
64172a97
...
@@ -11,7 +11,7 @@ def test_swap() -> None:
...
@@ -11,7 +11,7 @@ def test_swap() -> None:
dtype
=
"half"
,
dtype
=
"half"
,
load_format
=
"dummy"
)
load_format
=
"dummy"
)
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
_
)
=
engine_args
.
create_engine_configs
()
device_config
,
_
,
_
)
=
engine_args
.
create_engine_configs
()
cache_config
.
num_gpu_blocks
=
100
cache_config
.
num_gpu_blocks
=
100
cache_config
.
num_cpu_blocks
=
100
cache_config
.
num_cpu_blocks
=
100
...
...
vllm/config.py
View file @
64172a97
import
enum
import
json
import
json
import
os
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -8,7 +9,7 @@ from packaging.version import Version
...
@@ -8,7 +9,7 @@ from packaging.version import Version
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
get_cpu_memory
,
get_nvcc_cuda_version
,
is_hip
,
is_neuron
from
vllm.utils
import
get_cpu_memory
,
get_nvcc_cuda_version
,
is_hip
,
is_neuron
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -118,8 +119,9 @@ class ModelConfig:
...
@@ -118,8 +119,9 @@ class ModelConfig:
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
)
code_revision
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_config
,
dtype
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_config
,
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
max_model_len
)
max_model_len
)
self
.
_verify_load_format
()
self
.
_verify_load_format
()
self
.
_verify_tokenizer_mode
()
self
.
_verify_tokenizer_mode
()
...
@@ -218,7 +220,7 @@ class ModelConfig:
...
@@ -218,7 +220,7 @@ class ModelConfig:
self
,
self
,
parallel_config
:
"ParallelConfig"
,
parallel_config
:
"ParallelConfig"
,
)
->
None
:
)
->
None
:
total_num_attention_heads
=
self
.
hf_config
.
num_attention_heads
total_num_attention_heads
=
self
.
hf_
text_
config
.
num_attention_heads
tensor_parallel_size
=
parallel_config
.
tensor_parallel_size
tensor_parallel_size
=
parallel_config
.
tensor_parallel_size
if
total_num_attention_heads
%
tensor_parallel_size
!=
0
:
if
total_num_attention_heads
%
tensor_parallel_size
!=
0
:
raise
ValueError
(
raise
ValueError
(
...
@@ -226,7 +228,7 @@ class ModelConfig:
...
@@ -226,7 +228,7 @@ class ModelConfig:
" must be divisible by tensor parallel size "
" must be divisible by tensor parallel size "
f
"(
{
tensor_parallel_size
}
)."
)
f
"(
{
tensor_parallel_size
}
)."
)
total_num_hidden_layers
=
self
.
hf_config
.
num_hidden_layers
total_num_hidden_layers
=
self
.
hf_
text_
config
.
num_hidden_layers
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
if
total_num_hidden_layers
%
pipeline_parallel_size
!=
0
:
if
total_num_hidden_layers
%
pipeline_parallel_size
!=
0
:
raise
ValueError
(
raise
ValueError
(
...
@@ -241,22 +243,23 @@ class ModelConfig:
...
@@ -241,22 +243,23 @@ class ModelConfig:
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# addition to sliding window size. We check if that field is present
# and if it's False, return None.
# and if it's False, return None.
if
(
hasattr
(
self
.
hf_config
,
"use_sliding_window"
)
if
(
hasattr
(
self
.
hf_
text_
config
,
"use_sliding_window"
)
and
not
self
.
hf_config
.
use_sliding_window
):
and
not
self
.
hf_
text_
config
.
use_sliding_window
):
return
None
return
None
return
getattr
(
self
.
hf_config
,
"sliding_window"
,
None
)
return
getattr
(
self
.
hf_
text_
config
,
"sliding_window"
,
None
)
def
get_vocab_size
(
self
)
->
int
:
def
get_vocab_size
(
self
)
->
int
:
return
self
.
hf_config
.
vocab_size
return
self
.
hf_
text_
config
.
vocab_size
def
get_hidden_size
(
self
)
->
int
:
def
get_hidden_size
(
self
)
->
int
:
return
self
.
hf_config
.
hidden_size
return
self
.
hf_
text_
config
.
hidden_size
def
get_head_size
(
self
)
->
int
:
def
get_head_size
(
self
)
->
int
:
if
hasattr
(
self
.
hf_config
,
"head_dim"
):
if
hasattr
(
self
.
hf_
text_
config
,
"head_dim"
):
return
self
.
hf_config
.
head_dim
return
self
.
hf_
text_
config
.
head_dim
# FIXME(woosuk): This may not be true for all models.
# FIXME(woosuk): This may not be true for all models.
return
self
.
hf_config
.
hidden_size
//
self
.
hf_config
.
num_attention_heads
return
(
self
.
hf_text_config
.
hidden_size
//
self
.
hf_text_config
.
num_attention_heads
)
def
get_total_num_kv_heads
(
self
)
->
int
:
def
get_total_num_kv_heads
(
self
)
->
int
:
"""Returns the total number of KV heads."""
"""Returns the total number of KV heads."""
...
@@ -268,7 +271,7 @@ class ModelConfig:
...
@@ -268,7 +271,7 @@ class ModelConfig:
new_decoder_arch_falcon
=
(
new_decoder_arch_falcon
=
(
self
.
hf_config
.
model_type
in
falcon_model_types
self
.
hf_config
.
model_type
in
falcon_model_types
and
getattr
(
self
.
hf_config
,
"new_decoder_architecture"
,
False
))
and
getattr
(
self
.
hf_config
,
"new_decoder_architecture"
,
False
))
if
not
new_decoder_arch_falcon
and
getattr
(
self
.
hf_config
,
if
not
new_decoder_arch_falcon
and
getattr
(
self
.
hf_
text_
config
,
"multi_query"
,
False
):
"multi_query"
,
False
):
# Multi-query attention, only one KV head.
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
# Currently, tensor parallelism is not supported in this case.
...
@@ -284,13 +287,13 @@ class ModelConfig:
...
@@ -284,13 +287,13 @@ class ModelConfig:
"multi_query_group_num"
,
"multi_query_group_num"
,
]
]
for
attr
in
attributes
:
for
attr
in
attributes
:
num_kv_heads
=
getattr
(
self
.
hf_config
,
attr
,
None
)
num_kv_heads
=
getattr
(
self
.
hf_
text_
config
,
attr
,
None
)
if
num_kv_heads
is
not
None
:
if
num_kv_heads
is
not
None
:
return
num_kv_heads
return
num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
# equal to the number of attention heads.
return
self
.
hf_config
.
num_attention_heads
return
self
.
hf_
text_
config
.
num_attention_heads
def
get_num_kv_heads
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
def
get_num_kv_heads
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
"""Returns the number of KV heads per GPU."""
"""Returns the number of KV heads per GPU."""
...
@@ -303,7 +306,7 @@ class ModelConfig:
...
@@ -303,7 +306,7 @@ class ModelConfig:
total_num_kv_heads
//
parallel_config
.
tensor_parallel_size
)
total_num_kv_heads
//
parallel_config
.
tensor_parallel_size
)
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
def
get_num_layers
(
self
,
parallel_config
:
"ParallelConfig"
)
->
int
:
total_num_hidden_layers
=
self
.
hf_config
.
num_hidden_layers
total_num_hidden_layers
=
self
.
hf_
text_
config
.
num_hidden_layers
return
total_num_hidden_layers
//
parallel_config
.
pipeline_parallel_size
return
total_num_hidden_layers
//
parallel_config
.
pipeline_parallel_size
...
@@ -627,6 +630,48 @@ class LoRAConfig:
...
@@ -627,6 +630,48 @@ class LoRAConfig:
"LoRA is enabled."
)
"LoRA is enabled."
)
@
dataclass
class
VisionLanguageConfig
:
"""Configs the input data format and how models should run for
vision language models."""
class
ImageInputType
(
enum
.
Enum
):
"""Image input type into the vision language model.
An image roughly goes through the following transformation:
Raw image --> pixel values --> image features --> image embeddings.
The difference between different image input types is where the
image encoder (pixel values --> image features) is run.
Different image input types also correspond to different tensor shapes.
For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
IMAGE_FEATURES: (1, 576, 1024).
"""
PIXEL_VALUES
=
enum
.
auto
()
IMAGE_FEATURES
=
enum
.
auto
()
image_input_type
:
ImageInputType
# The input id corresponding to image token.
image_token_id
:
int
# Used for running `run_prefill_max_token`.
# For models that support varying resolution, this corresponds to
# worst case scenario (biggest supported resolution).
image_input_shape
:
tuple
image_feature_size
:
int
@
classmethod
def
get_image_input_enum_type
(
cls
,
value
:
str
)
->
"VisionLanguageConfig.ImageInputType"
:
"""Get the image input type from a string."""
try
:
return
cls
.
ImageInputType
[
value
.
upper
()]
except
KeyError
as
e
:
raise
ValueError
(
f
"
{
value
}
is not a valid choice. "
f
"Expecting to choose from "
f
"
{
[
x
.
name
for
x
in
cls
.
ImageInputType
]
}
."
)
from
e
_STR_DTYPE_TO_TORCH_DTYPE
=
{
_STR_DTYPE_TO_TORCH_DTYPE
=
{
"half"
:
torch
.
float16
,
"half"
:
torch
.
float16
,
"float16"
:
torch
.
float16
,
"float16"
:
torch
.
float16
,
...
...
vllm/core/scheduler.py
View file @
64172a97
...
@@ -388,6 +388,12 @@ class Scheduler:
...
@@ -388,6 +388,12 @@ class Scheduler:
computed_block_nums
=
self
.
block_manager
.
computed_block_nums
=
self
.
block_manager
.
get_common_computed_block_ids
(
seq_group
),
get_common_computed_block_ids
(
seq_group
),
state
=
seq_group
.
state
,
state
=
seq_group
.
state
,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data
=
seq_group
.
multi_modal_data
if
scheduler_outputs
.
prompt_run
else
None
,
)
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
return
seq_group_metadata_list
,
scheduler_outputs
return
seq_group_metadata_list
,
scheduler_outputs
...
...
vllm/engine/arg_utils.py
View file @
64172a97
...
@@ -4,7 +4,9 @@ from dataclasses import dataclass
...
@@ -4,7 +4,9 @@ from dataclasses import dataclass
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
TokenizerPoolConfig
)
ParallelConfig
,
SchedulerConfig
,
TokenizerPoolConfig
,
VisionLanguageConfig
)
from
vllm.utils
import
str_to_int_tuple
@
dataclass
@
dataclass
...
@@ -50,6 +52,11 @@ class EngineArgs:
...
@@ -50,6 +52,11 @@ class EngineArgs:
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
device
:
str
=
'auto'
device
:
str
=
'auto'
ray_workers_use_nsight
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
# Related to Vision-language models such as llava
image_input_type
:
Optional
[
str
]
=
None
image_token_id
:
Optional
[
int
]
=
None
image_input_shape
:
Optional
[
str
]
=
None
image_feature_size
:
Optional
[
int
]
=
None
scheduler_delay_factor
:
float
=
0.0
scheduler_delay_factor
:
float
=
0.0
def
__post_init__
(
self
):
def
__post_init__
(
self
):
...
@@ -305,6 +312,31 @@ class EngineArgs:
...
@@ -305,6 +312,31 @@ class EngineArgs:
default
=
EngineArgs
.
device
,
default
=
EngineArgs
.
device
,
choices
=
[
"auto"
,
"cuda"
,
"neuron"
],
choices
=
[
"auto"
,
"cuda"
,
"neuron"
],
help
=
'Device type for vLLM execution.'
)
help
=
'Device type for vLLM execution.'
)
# Related to Vision-language models such as llava
parser
.
add_argument
(
'--image-input-type'
,
type
=
str
,
default
=
None
,
choices
=
[
t
.
name
.
lower
()
for
t
in
VisionLanguageConfig
.
ImageInputType
],
help
=
(
'The image input type passed into vLLM. '
'Should be one of "pixel_values" or "image_features".'
))
parser
.
add_argument
(
'--image-token-id'
,
type
=
int
,
default
=
None
,
help
=
(
'Input id for image token.'
))
parser
.
add_argument
(
'--image-input-shape'
,
type
=
str
,
default
=
None
,
help
=
(
'The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM
\'
s profile_run.'
))
parser
.
add_argument
(
'--image-feature-size'
,
type
=
int
,
default
=
None
,
help
=
(
'The image feature size along the context dimension.'
))
parser
.
add_argument
(
parser
.
add_argument
(
'--scheduler-delay-factor'
,
'--scheduler-delay-factor'
,
type
=
float
,
type
=
float
,
...
@@ -324,7 +356,8 @@ class EngineArgs:
...
@@ -324,7 +356,8 @@ class EngineArgs:
def
create_engine_configs
(
def
create_engine_configs
(
self
,
self
,
)
->
Tuple
[
ModelConfig
,
CacheConfig
,
ParallelConfig
,
SchedulerConfig
,
)
->
Tuple
[
ModelConfig
,
CacheConfig
,
ParallelConfig
,
SchedulerConfig
,
DeviceConfig
,
Optional
[
LoRAConfig
]]:
DeviceConfig
,
Optional
[
LoRAConfig
],
Optional
[
VisionLanguageConfig
]]:
device_config
=
DeviceConfig
(
self
.
device
)
device_config
=
DeviceConfig
(
self
.
device
)
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
...
@@ -358,8 +391,25 @@ class EngineArgs:
...
@@ -358,8 +391,25 @@ class EngineArgs:
lora_dtype
=
self
.
lora_dtype
,
lora_dtype
=
self
.
lora_dtype
,
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
and
self
.
max_cpu_loras
>
0
else
None
)
if
self
.
enable_lora
else
None
and
self
.
max_cpu_loras
>
0
else
None
)
if
self
.
enable_lora
else
None
if
self
.
image_input_type
:
if
(
not
self
.
image_token_id
or
not
self
.
image_input_shape
or
not
self
.
image_feature_size
):
raise
ValueError
(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.'
)
vision_language_config
=
VisionLanguageConfig
(
image_input_type
=
VisionLanguageConfig
.
get_image_input_enum_type
(
self
.
image_input_type
),
image_token_id
=
self
.
image_token_id
,
image_input_shape
=
str_to_int_tuple
(
self
.
image_input_shape
),
image_feature_size
=
self
.
image_feature_size
,
)
else
:
vision_language_config
=
None
return
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
return
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
lora_config
)
device_config
,
lora_config
,
vision_language_config
)
@
dataclass
@
dataclass
...
...
vllm/engine/async_llm_engine.py
View file @
64172a97
...
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
...
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
int
(
ENGINE_ITERATION_TIMEOUT_S
=
int
(
...
@@ -240,6 +241,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -240,6 +241,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
)
->
None
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
...
@@ -252,14 +254,13 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -252,14 +254,13 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
lora_request
=
lora_request
)
lora_request
=
lora_request
)
return
self
.
add_request
(
return
self
.
add_request
(
request_id
,
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
)
)
async
def
check_health_async
(
self
)
->
None
:
async
def
check_health_async
(
self
)
->
None
:
self
.
model_executor
.
check_health
()
self
.
model_executor
.
check_health
()
...
@@ -486,6 +487,7 @@ class AsyncLLMEngine:
...
@@ -486,6 +487,7 @@ class AsyncLLMEngine:
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
AsyncStream
:
)
->
AsyncStream
:
if
self
.
log_requests
:
if
self
.
log_requests
:
shortened_prompt
=
prompt
shortened_prompt
=
prompt
...
@@ -534,7 +536,9 @@ class AsyncLLMEngine:
...
@@ -534,7 +536,9 @@ class AsyncLLMEngine:
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
)
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
,
)
return
stream
return
stream
...
@@ -545,6 +549,7 @@ class AsyncLLMEngine:
...
@@ -545,6 +549,7 @@ class AsyncLLMEngine:
request_id
:
str
,
request_id
:
str
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
)
->
AsyncIterator
[
RequestOutput
]:
)
->
AsyncIterator
[
RequestOutput
]:
"""Generate outputs for a request.
"""Generate outputs for a request.
...
@@ -560,6 +565,7 @@ class AsyncLLMEngine:
...
@@ -560,6 +565,7 @@ class AsyncLLMEngine:
prompt_token_ids: The token IDs of the prompt. If None, we
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
Yields:
The output `RequestOutput` objects from the LLMEngine for the
The output `RequestOutput` objects from the LLMEngine for the
...
@@ -619,6 +625,7 @@ class AsyncLLMEngine:
...
@@ -619,6 +625,7 @@ class AsyncLLMEngine:
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
,
)
)
async
for
request_output
in
stream
:
async
for
request_output
in
stream
:
...
...
vllm/engine/llm_engine.py
View file @
64172a97
...
@@ -5,7 +5,7 @@ from transformers import PreTrainedTokenizer
...
@@ -5,7 +5,7 @@ from transformers import PreTrainedTokenizer
import
vllm
import
vllm
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
StatLogger
,
Stats
from
vllm.engine.metrics
import
StatLogger
,
Stats
...
@@ -15,8 +15,9 @@ from vllm.logger import init_logger
...
@@ -15,8 +15,9 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
SamplerOutput
,
Sequence
,
SequenceGroup
,
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
Sequence
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
get_tokenizer_group
)
get_tokenizer_group
)
...
@@ -62,6 +63,7 @@ class LLMEngine:
...
@@ -62,6 +63,7 @@ class LLMEngine:
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
"VisionLanguageConfig"
],
executor_class
:
Type
[
ExecutorBase
],
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
log_stats
:
bool
,
)
->
None
:
)
->
None
:
...
@@ -90,6 +92,7 @@ class LLMEngine:
...
@@ -90,6 +92,7 @@ class LLMEngine:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
vision_language_config
=
vision_language_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
...
@@ -102,7 +105,8 @@ class LLMEngine:
...
@@ -102,7 +105,8 @@ class LLMEngine:
self
.
model_executor
=
executor_class
(
model_config
,
cache_config
,
self
.
model_executor
=
executor_class
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
parallel_config
,
scheduler_config
,
device_config
,
lora_config
)
device_config
,
lora_config
,
vision_language_config
)
# Ping the tokenizer to ensure liveness if it runs in a
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
# different process.
...
@@ -170,7 +174,6 @@ class LLMEngine:
...
@@ -170,7 +174,6 @@ class LLMEngine:
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
revision
=
self
.
model_config
.
tokenizer_revision
)
revision
=
self
.
model_config
.
tokenizer_revision
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
self
.
tokenizer
:
BaseTokenizerGroup
=
get_tokenizer_group
(
self
.
tokenizer
:
BaseTokenizerGroup
=
get_tokenizer_group
(
self
.
parallel_config
.
tokenizer_pool_config
,
**
init_kwargs
)
self
.
parallel_config
.
tokenizer_pool_config
,
**
init_kwargs
)
...
@@ -212,6 +215,7 @@ class LLMEngine:
...
@@ -212,6 +215,7 @@ class LLMEngine:
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
)
->
None
:
"""Add a request to the engine's request pool.
"""Add a request to the engine's request pool.
...
@@ -228,6 +232,7 @@ class LLMEngine:
...
@@ -228,6 +232,7 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
the current monotonic time.
multi_modal_data: Multi modal data per request.
Details:
Details:
- Set arrival_time to the current time if it is None.
- Set arrival_time to the current time if it is None.
...
@@ -288,7 +293,7 @@ class LLMEngine:
...
@@ -288,7 +293,7 @@ class LLMEngine:
# Create the sequence group.
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
arrival_time
,
lora_request
)
arrival_time
,
lora_request
,
multi_modal_data
)
# Add the sequence group to the scheduler.
# Add the sequence group to the scheduler.
self
.
scheduler
.
add_seq_group
(
seq_group
)
self
.
scheduler
.
add_seq_group
(
seq_group
)
...
...
vllm/entrypoints/llm.py
View file @
64172a97
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
import
torch
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
...
@@ -8,6 +9,7 @@ from vllm.engine.llm_engine import LLMEngine
...
@@ -8,6 +9,7 @@ from vllm.engine.llm_engine import LLMEngine
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -126,6 +128,7 @@ class LLM:
...
@@ -126,6 +128,7 @@ class LLM:
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
...
@@ -141,6 +144,7 @@ class LLM:
...
@@ -141,6 +144,7 @@ class LLM:
use the tokenizer to convert the prompts to token IDs.
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
Returns:
A list of `RequestOutput` objects containing the generated
A list of `RequestOutput` objects containing the generated
...
@@ -160,6 +164,9 @@ class LLM:
...
@@ -160,6 +164,9 @@ class LLM:
# Use default sampling params.
# Use default sampling params.
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
()
if
multi_modal_data
:
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
# Add requests to the engine.
# Add requests to the engine.
num_requests
=
len
(
prompts
)
if
prompts
is
not
None
else
len
(
num_requests
=
len
(
prompts
)
if
prompts
is
not
None
else
len
(
prompt_token_ids
)
prompt_token_ids
)
...
@@ -167,10 +174,17 @@ class LLM:
...
@@ -167,10 +174,17 @@ class LLM:
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
i
]
i
]
self
.
_add_request
(
prompt
,
self
.
_add_request
(
sampling_params
,
prompt
,
token_ids
,
sampling_params
,
lora_request
=
lora_request
)
token_ids
,
lora_request
=
lora_request
,
# Get ith image while maintaining the batch dim.
multi_modal_data
=
MultiModalData
(
type
=
multi_modal_data
.
type
,
data
=
multi_modal_data
.
data
[
i
].
unsqueeze
(
0
))
if
multi_modal_data
else
None
,
)
return
self
.
_run_engine
(
use_tqdm
)
return
self
.
_run_engine
(
use_tqdm
)
def
_add_request
(
def
_add_request
(
...
@@ -179,13 +193,15 @@ class LLM:
...
@@ -179,13 +193,15 @@ class LLM:
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
prompt_token_ids
:
Optional
[
List
[
int
]],
prompt_token_ids
:
Optional
[
List
[
int
]],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
None
:
)
->
None
:
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
self
.
llm_engine
.
add_request
(
request_id
,
prompt
,
prompt
,
sampling_params
,
sampling_params
,
prompt_token_ids
,
prompt_token_ids
,
lora_request
=
lora_request
)
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
)
def
_run_engine
(
self
,
use_tqdm
:
bool
)
->
List
[
RequestOutput
]:
def
_run_engine
(
self
,
use_tqdm
:
bool
)
->
List
[
RequestOutput
]:
# Initialize tqdm.
# Initialize tqdm.
...
...
vllm/executor/executor_base.py
View file @
64172a97
...
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
...
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
...
@@ -24,6 +24,7 @@ class ExecutorBase(ABC):
...
@@ -24,6 +24,7 @@ class ExecutorBase(ABC):
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/executor/gpu_executor.py
View file @
64172a97
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.utils
import
check_block_size_valid
from
vllm.executor.utils
import
check_block_size_valid
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -23,6 +23,7 @@ class GPUExecutor(ExecutorBase):
...
@@ -23,6 +23,7 @@ class GPUExecutor(ExecutorBase):
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
@@ -30,6 +31,7 @@ class GPUExecutor(ExecutorBase):
...
@@ -30,6 +31,7 @@ class GPUExecutor(ExecutorBase):
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
vision_language_config
=
vision_language_config
# Instantiate the worker and load the model to GPU.
# Instantiate the worker and load the model to GPU.
self
.
_init_worker
()
self
.
_init_worker
()
...
@@ -56,6 +58,7 @@ class GPUExecutor(ExecutorBase):
...
@@ -56,6 +58,7 @@ class GPUExecutor(ExecutorBase):
rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
...
...
vllm/executor/ray_gpu_executor.py
View file @
64172a97
...
@@ -6,7 +6,7 @@ from collections import defaultdict
...
@@ -6,7 +6,7 @@ from collections import defaultdict
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.engine.ray_utils
import
RayWorkerVllm
,
ray
from
vllm.engine.ray_utils
import
RayWorkerVllm
,
ray
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.utils
import
check_block_size_valid
from
vllm.executor.utils
import
check_block_size_valid
...
@@ -40,6 +40,7 @@ class RayGPUExecutor(ExecutorBase):
...
@@ -40,6 +40,7 @@ class RayGPUExecutor(ExecutorBase):
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
@@ -47,6 +48,7 @@ class RayGPUExecutor(ExecutorBase):
...
@@ -47,6 +48,7 @@ class RayGPUExecutor(ExecutorBase):
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
vision_language_config
=
vision_language_config
assert
self
.
parallel_config
.
worker_use_ray
assert
self
.
parallel_config
.
worker_use_ray
placement_group
=
self
.
parallel_config
.
placement_group
placement_group
=
self
.
parallel_config
.
placement_group
...
@@ -181,6 +183,7 @@ class RayGPUExecutor(ExecutorBase):
...
@@ -181,6 +183,7 @@ class RayGPUExecutor(ExecutorBase):
driver_rank
,
driver_rank
,
distributed_init_method
,
distributed_init_method
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
...
...
vllm/model_executor/model_loader.py
View file @
64172a97
...
@@ -7,9 +7,14 @@ import torch.nn as nn
...
@@ -7,9 +7,14 @@ import torch.nn as nn
from
vllm.config
import
DeviceConfig
,
ModelConfig
from
vllm.config
import
DeviceConfig
,
ModelConfig
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.llava
import
LlavaForConditionalGeneration
from
vllm.model_executor.weight_utils
import
(
get_quant_config
,
from
vllm.model_executor.weight_utils
import
(
get_quant_config
,
initialize_dummy_weights
)
initialize_dummy_weights
)
_VISION_MODEL_CLASSES
=
[
LlavaForConditionalGeneration
,
]
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
_set_default_torch_dtype
(
dtype
:
torch
.
dtype
):
def
_set_default_torch_dtype
(
dtype
:
torch
.
dtype
):
...
@@ -40,6 +45,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
...
@@ -40,6 +45,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
def
get_model
(
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
def
get_model
(
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
**
kwargs
)
->
nn
.
Module
:
**
kwargs
)
->
nn
.
Module
:
lora_config
=
kwargs
.
get
(
"lora_config"
,
None
)
lora_config
=
kwargs
.
get
(
"lora_config"
,
None
)
vision_language_config
=
kwargs
.
get
(
"vision_language_config"
,
None
)
model_class
=
_get_model_architecture
(
model_config
)
model_class
=
_get_model_architecture
(
model_config
)
# Get the (maybe quantized) linear method.
# Get the (maybe quantized) linear method.
...
@@ -76,7 +82,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
...
@@ -76,7 +82,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
"be added in the future. If this is important to you, "
"be added in the future. If this is important to you, "
"please open an issue on github."
)
"please open an issue on github."
)
else
:
else
:
model
=
model_class
(
model_config
.
hf_config
,
linear_method
)
if
model_class
not
in
_VISION_MODEL_CLASSES
:
model
=
model_class
(
model_config
.
hf_config
,
linear_method
)
else
:
model
=
model_class
(
model_config
.
hf_config
,
vision_language_config
,
linear_method
)
if
model_config
.
load_format
==
"dummy"
:
if
model_config
.
load_format
==
"dummy"
:
# NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# random values to the weights.
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment