Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d58268c5
Unverified
Commit
d58268c5
authored
Nov 06, 2024
by
Joe Runde
Committed by
GitHub
Nov 06, 2024
Browse files
[V1] Make v1 more testable (#9888)
Signed-off-by:
Joe Runde
<
Joseph.Runde@ibm.com
>
parent
87bd7e05
Changes
75
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
124 additions
and
49 deletions
+124
-49
Dockerfile
Dockerfile
+3
-0
pyproject.toml
pyproject.toml
+1
-0
tests/conftest.py
tests/conftest.py
+18
-0
tests/entrypoints/llm/test_prompt_validation.py
tests/entrypoints/llm/test_prompt_validation.py
+9
-0
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+2
-0
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+2
-2
vllm/attention/selector.py
vllm/attention/selector.py
+33
-10
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+10
-8
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+17
-9
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+9
-0
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+2
-2
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+2
-2
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+2
-2
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+2
-2
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+2
-2
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+2
-2
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+2
-2
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+2
-2
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+2
-2
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+2
-2
No files found.
Dockerfile
View file @
d58268c5
...
@@ -191,6 +191,9 @@ ADD . /vllm-workspace/
...
@@ -191,6 +191,9 @@ ADD . /vllm-workspace/
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
python3
-m
pip
install
-r
requirements-dev.txt
python3
-m
pip
install
-r
requirements-dev.txt
# Copy in the v1 package for testing (it isn't distributed yet)
COPY
vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
# doc requires source code
# doc requires source code
# we hide them inside `test_docs/` , so that this source code
# we hide them inside `test_docs/` , so that this source code
# will not be imported by other tests
# will not be imported by other tests
...
...
pyproject.toml
View file @
d58268c5
...
@@ -97,4 +97,5 @@ markers = [
...
@@ -97,4 +97,5 @@ markers = [
"skip_global_cleanup"
,
"skip_global_cleanup"
,
"core_model: run this model test in each PR instead of just daily"
,
"core_model: run this model test in each PR instead of just daily"
,
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs"
,
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs"
,
"skip_v1: do not run this test with v1"
,
]
]
tests/conftest.py
View file @
d58268c5
...
@@ -5,6 +5,7 @@ from collections import UserList
...
@@ -5,6 +5,7 @@ from collections import UserList
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypedDict
,
TypeVar
,
Union
)
TypedDict
,
TypeVar
,
Union
)
from
unittest.mock
import
patch
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -108,6 +109,23 @@ VIDEO_ASSETS = _VideoAssets()
...
@@ -108,6 +109,23 @@ VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
"""Singleton instance of :class:`_VideoAssets`."""
@pytest.fixture(params=[True, False])
def run_with_both_engines(request):
    """Run the requesting test twice, once with V1 enabled and once without.

    The fixture is parametrized over ``True``/``False``; the current value is
    patched into ``vllm.envs.VLLM_USE_V1`` for the duration of the test.
    Tests carrying the ``skip_v1`` marker are skipped for the V1 pass and
    only run with V1 disabled.
    """
    v1_enabled = request.param
    skip_marker = request.node.get_closest_marker("skip_v1")

    # Guard clause: marked tests never run against V1.
    if v1_enabled and skip_marker:
        pytest.skip("Skipping test on vllm V1")

    # The patch stays active while the test body runs (between the yield
    # and the fixture teardown), then the original value is restored.
    with patch('vllm.envs.VLLM_USE_V1', v1_enabled):
        yield
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
init_test_http_connection
():
def
init_test_http_connection
():
# pytest_asyncio may use a different event loop per test
# pytest_asyncio may use a different event loop per test
...
...
tests/entrypoints/llm/test_prompt_validation.py
View file @
d58268c5
...
@@ -3,12 +3,21 @@ import pytest
...
@@ -3,12 +3,21 @@ import pytest
from
vllm
import
LLM
from
vllm
import
LLM
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Autouse wrapper that pulls in ``run_with_both_engines`` for each test.

    Requesting the fixture is all that is needed, so there is no body.
    This can be promoted up to conftest.py to run for every test in a
    package.
    """
def
test_empty_prompt
():
def
test_empty_prompt
():
llm
=
LLM
(
model
=
"gpt2"
,
enforce_eager
=
True
)
llm
=
LLM
(
model
=
"gpt2"
,
enforce_eager
=
True
)
with
pytest
.
raises
(
ValueError
,
match
=
'Prompt cannot be empty'
):
with
pytest
.
raises
(
ValueError
,
match
=
'Prompt cannot be empty'
):
llm
.
generate
([
""
])
llm
.
generate
([
""
])
@
pytest
.
mark
.
skip_v1
def
test_out_of_vocab_token
():
def
test_out_of_vocab_token
():
llm
=
LLM
(
model
=
"gpt2"
,
enforce_eager
=
True
)
llm
=
LLM
(
model
=
"gpt2"
,
enforce_eager
=
True
)
with
pytest
.
raises
(
ValueError
,
match
=
'out of vocabulary'
):
with
pytest
.
raises
(
ValueError
,
match
=
'out of vocabulary'
):
...
...
tests/kernels/test_attention_selector.py
View file @
d58268c5
...
@@ -44,6 +44,8 @@ def test_env(name: str, device: str, monkeypatch):
...
@@ -44,6 +44,8 @@ def test_env(name: str, device: str, monkeypatch):
def
test_flash_attn
(
monkeypatch
):
def
test_flash_attn
(
monkeypatch
):
"""Test FlashAttn validation."""
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# which_attn_to_use
override_backend_env_variable
(
monkeypatch
,
STR_FLASH_ATTN_VAL
)
override_backend_env_variable
(
monkeypatch
,
STR_FLASH_ATTN_VAL
)
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
d58268c5
...
@@ -16,7 +16,7 @@ from tests.kernels.utils import *
...
@@ -16,7 +16,7 @@ from tests.kernels.utils import *
from
vllm.attention
import
(
Attention
,
AttentionBackend
,
AttentionMetadata
,
from
vllm.attention
import
(
Attention
,
AttentionBackend
,
AttentionMetadata
,
AttentionType
)
AttentionType
)
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.selector
import
(
_Backend
,
get_attn_backend
,
from
vllm.attention.selector
import
(
_Backend
,
_cached_
get_attn_backend
,
global_force_attn_backend_context_manager
)
global_force_attn_backend_context_manager
)
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -774,7 +774,7 @@ def set_reset_environment(attn_backend):
...
@@ -774,7 +774,7 @@ def set_reset_environment(attn_backend):
default_dtype
=
torch
.
get_default_dtype
()
default_dtype
=
torch
.
get_default_dtype
()
if
attn_backend
.
name
==
'FLASH_ATTN'
:
if
attn_backend
.
name
==
'FLASH_ATTN'
:
torch
.
set_default_dtype
(
torch
.
bfloat16
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
get_attn_backend
.
cache_clear
()
_cached_
get_attn_backend
.
cache_clear
()
yield
yield
# Reset the torch datatype to what it was before the test
# Reset the torch datatype to what it was before the test
# so as not to impact the remaining tests.
# so as not to impact the remaining tests.
...
...
vllm/attention/selector.py
View file @
d58268c5
...
@@ -89,7 +89,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
...
@@ -89,7 +89,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return
forced_attn_backend
return
forced_attn_backend
@
lru_cache
(
maxsize
=
None
)
def
get_attn_backend
(
def
get_attn_backend
(
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
...
@@ -99,6 +98,31 @@ def get_attn_backend(
...
@@ -99,6 +98,31 @@ def get_attn_backend(
is_blocksparse
:
bool
=
False
,
is_blocksparse
:
bool
=
False
,
)
->
Type
[
AttentionBackend
]:
)
->
Type
[
AttentionBackend
]:
"""Selects which attention backend to use and lazily imports it."""
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
# private function.
return
_cached_get_attn_backend
(
head_size
=
head_size
,
dtype
=
dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
block_size
=
block_size
,
is_attention_free
=
is_attention_free
,
is_blocksparse
=
is_blocksparse
,
use_v1
=
envs
.
VLLM_USE_V1
,
)
@
lru_cache
(
maxsize
=
None
)
def
_cached_get_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
is_attention_free
:
bool
,
is_blocksparse
:
bool
=
False
,
use_v1
:
bool
=
False
,
)
->
Type
[
AttentionBackend
]:
if
is_blocksparse
:
if
is_blocksparse
:
logger
.
info
(
"Using BlocksparseFlashAttention backend."
)
logger
.
info
(
"Using BlocksparseFlashAttention backend."
)
from
vllm.attention.backends.blocksparse_attn
import
(
from
vllm.attention.backends.blocksparse_attn
import
(
...
@@ -106,7 +130,7 @@ def get_attn_backend(
...
@@ -106,7 +130,7 @@ def get_attn_backend(
return
BlocksparseFlashAttentionBackend
return
BlocksparseFlashAttentionBackend
backend
=
which_attn_to_use
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
backend
=
which_attn_to_use
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
is_attention_free
)
is_attention_free
,
use_v1
)
if
backend
==
_Backend
.
FLASH_ATTN
:
if
backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
"Using Flash Attention backend."
)
logger
.
info
(
"Using Flash Attention backend."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
...
@@ -162,13 +186,12 @@ def get_attn_backend(
...
@@ -162,13 +186,12 @@ def get_attn_backend(
raise
ValueError
(
"Invalid attention backend."
)
raise
ValueError
(
"Invalid attention backend."
)
def
which_attn_to_use
(
def
which_attn_to_use
(
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
block_size
:
int
,
is_attention_free
:
bool
,
is_attention_free
:
bool
,
use_v1
:
bool
=
False
)
->
_Backend
:
)
->
_Backend
:
"""Returns which flash attention backend to use."""
"""Returns which flash attention backend to use."""
# Default case.
# Default case.
selected_backend
=
_Backend
.
FLASH_ATTN
selected_backend
=
_Backend
.
FLASH_ATTN
...
@@ -228,7 +251,7 @@ def which_attn_to_use(
...
@@ -228,7 +251,7 @@ def which_attn_to_use(
if
current_platform
.
is_hpu
():
if
current_platform
.
is_hpu
():
return
_Backend
.
HPU_ATTN
return
_Backend
.
HPU_ATTN
if
envs
.
VLLM_USE_V
1
:
if
use_v
1
:
return
_Backend
.
FLASH_ATTN_VLLM_V1
return
_Backend
.
FLASH_ATTN_VLLM_V1
# FlashAttn in NVIDIA GPUs.
# FlashAttn in NVIDIA GPUs.
...
...
vllm/engine/multiprocessing/engine.py
View file @
d58268c5
...
@@ -6,7 +6,9 @@ from typing import Iterator, List, Optional, Union
...
@@ -6,7 +6,9 @@ from typing import Iterator, List, Optional, Union
import
cloudpickle
import
cloudpickle
import
zmq
import
zmq
import
vllm.envs
from
vllm
import
AsyncEngineArgs
,
SamplingParams
from
vllm
import
AsyncEngineArgs
,
SamplingParams
from
vllm.engine.llm_engine
import
LLMEngine
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
from
vllm.engine.multiprocessing
import
(
ENGINE_DEAD_ERROR
,
IPC_DATA_EXT
,
from
vllm.engine.multiprocessing
import
(
ENGINE_DEAD_ERROR
,
IPC_DATA_EXT
,
...
@@ -17,17 +19,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -17,17 +19,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest
,
RPCStartupResponse
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
RPCUProfileRequest
)
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_USE_V1
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
if
VLLM_USE_V1
:
from
vllm.v1.engine.llm_engine
import
LLMEngine
else
:
from
vllm.engine.llm_engine
import
LLMEngine
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
POLLING_TIMEOUT_MS
=
10000
POLLING_TIMEOUT_MS
=
10000
...
@@ -117,11 +113,17 @@ class MQLLMEngine:
...
@@ -117,11 +113,17 @@ class MQLLMEngine:
load_general_plugins
()
load_general_plugins
()
engine_config
=
engine_args
.
create_engine_config
()
engine_config
=
engine_args
.
create_engine_config
()
if
vllm
.
envs
.
VLLM_USE_V1
:
# Lazy import: the v1 package isn't distributed
from
vllm.v1.engine.llm_engine
import
LLMEngine
as
V1LLMEngine
engine_class
=
V1LLMEngine
else
:
engine_class
=
LLMEngine
executor_class
=
LLME
ngine
.
_get_executor_cls
(
engine_config
)
executor_class
=
e
ngine
_class
.
_get_executor_cls
(
engine_config
)
use_async_sockets
=
(
engine_config
.
model_config
.
use_async_output_proc
use_async_sockets
=
(
engine_config
.
model_config
.
use_async_output_proc
and
not
VLLM_USE_V1
)
and
not
vllm
.
envs
.
VLLM_USE_V1
)
return
cls
(
ipc_path
=
ipc_path
,
return
cls
(
ipc_path
=
ipc_path
,
use_async_sockets
=
use_async_sockets
,
use_async_sockets
=
use_async_sockets
,
...
...
vllm/entrypoints/llm.py
View file @
d58268c5
import
itertools
import
itertools
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
Union
,
cast
,
overload
)
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -10,6 +10,7 @@ from vllm import envs
...
@@ -10,6 +10,7 @@ from vllm import envs
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
BeamSearchSequence
,
get_beam_search_score
)
BeamSearchSequence
,
get_beam_search_score
)
from
vllm.engine.arg_utils
import
EngineArgs
,
TaskOption
from
vllm.engine.arg_utils
import
EngineArgs
,
TaskOption
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
apply_hf_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
apply_mistral_chat_template
,
...
@@ -31,11 +32,6 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
...
@@ -31,11 +32,6 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
from
vllm.utils
import
Counter
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
if
envs
.
VLLM_USE_V1
:
from
vllm.v1.engine.llm_engine
import
LLMEngine
# type: ignore
else
:
from
vllm.engine.llm_engine
import
LLMEngine
# type: ignore
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -206,10 +202,21 @@ class LLM:
...
@@ -206,10 +202,21 @@ class LLM:
pooling_returned_token_ids
=
pooling_returned_token_ids
,
pooling_returned_token_ids
=
pooling_returned_token_ids
,
**
kwargs
,
**
kwargs
,
)
)
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self
.
engine_class
=
self
.
get_engine_class
()
self
.
llm_engine
=
self
.
engine_class
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
request_counter
=
Counter
()
self
.
request_counter
=
Counter
()
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
    """Return the engine class selected by ``envs.VLLM_USE_V1``.

    The flag is read at call time, so the choice is made at runtime
    rather than at import.
    """
    if not envs.VLLM_USE_V1:
        return LLMEngine

    # Deferred import: the v1 package isn't distributed, so it must only
    # be imported when V1 is actually requested.
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    return V1LLMEngine  # type: ignore
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
return
self
.
llm_engine
.
get_tokenizer_group
(
TokenizerGroup
).
tokenizer
return
self
.
llm_engine
.
get_tokenizer_group
(
TokenizerGroup
).
tokenizer
...
@@ -394,7 +401,7 @@ class LLM:
...
@@ -394,7 +401,7 @@ class LLM:
priority
=
priority
)
priority
=
priority
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
beam_search
(
def
beam_search
(
self
,
self
,
...
@@ -769,7 +776,8 @@ class LLM:
...
@@ -769,7 +776,8 @@ class LLM:
)
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
EmbeddingRequestOutput
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
EmbeddingRequestOutput
)
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
self
.
llm_engine
.
start_profile
()
...
...
vllm/model_executor/layers/sampler.py
View file @
d58268c5
...
@@ -30,6 +30,15 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
...
@@ -30,6 +30,15 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
else
:
else
:
flashinfer_top_k_top_p_sampling
=
None
flashinfer_top_k_top_p_sampling
=
None
def get_sampler() -> torch.nn.Module:
    """Build a new sampler instance chosen by ``envs.VLLM_USE_V1``.

    Returns a ``Sampler`` from this module when the flag is falsy, and the
    V1 sampler otherwise.
    """
    if not envs.VLLM_USE_V1:
        return Sampler()

    # Deferred import: the v1 package isn't distributed, so it must only
    # be imported when V1 is actually requested.
    from vllm.v1.sample.sampler import Sampler as V1Sampler
    return V1Sampler()
# (num_token_ids, num_parent_ids) per sequence group.
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
...
...
vllm/model_executor/models/arctic.py
View file @
d58268c5
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
,
DeepSpeedFPParameter
)
DeepSpeedFPConfig
,
DeepSpeedFPParameter
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -436,7 +436,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
...
@@ -436,7 +436,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
config
.
vocab_size
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/baichuan.py
View file @
d58268c5
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -352,7 +352,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -352,7 +352,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/bart.py
View file @
d58268c5
...
@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -838,7 +838,7 @@ class BartForConditionalGeneration(nn.Module):
...
@@ -838,7 +838,7 @@ class BartForConditionalGeneration(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
config
.
vocab_size
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/blip2.py
View file @
d58268c5
...
@@ -13,7 +13,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
...
@@ -13,7 +13,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext
,
token_inputs
)
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
consecutive_placeholder_ranges
from
vllm.multimodal.utils
import
consecutive_placeholder_ranges
...
@@ -525,7 +525,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -525,7 +525,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if
hasattr
(
self
.
language_model
,
"sampler"
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
return
self
.
language_model
.
sampler
return
S
ampler
()
return
get_s
ampler
()
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
h
=
w
=
self
.
config
.
vision_config
.
image_size
...
...
vllm/model_executor/models/bloom.py
View file @
d58268c5
...
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -298,7 +298,7 @@ class BloomForCausalLM(nn.Module, SupportsPP):
...
@@ -298,7 +298,7 @@ class BloomForCausalLM(nn.Module, SupportsPP):
self
.
config
.
hidden_size
)
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
self
.
transformer
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/chameleon.py
View file @
d58268c5
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -946,7 +946,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -946,7 +946,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/chatglm.py
View file @
d58268c5
...
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -616,7 +616,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
...
@@ -616,7 +616,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self
.
transformer
.
embedding
.
weight
)
self
.
transformer
.
embedding
.
weight
)
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/commandr.py
View file @
d58268c5
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -355,7 +355,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -355,7 +355,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
cache_config
,
cache_config
,
quant_config
,
quant_config
,
lora_config
=
lora_config
)
lora_config
=
lora_config
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/dbrx.py
View file @
d58268c5
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -373,7 +373,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
...
@@ -373,7 +373,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
)
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
config
.
vocab_size
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
self
.
transformer
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/deepseek.py
View file @
d58268c5
...
@@ -41,7 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -41,7 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
Output
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -399,7 +399,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
...
@@ -399,7 +399,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
if
self
.
config
.
tie_word_embeddings
:
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
S
ampler
()
self
.
sampler
=
get_s
ampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment