Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
448
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1084 additions
and
119 deletions
+1084
-119
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+35
-1
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+7
-4
tests/tensorizer_loader/conftest.py
tests/tensorizer_loader/conftest.py
+12
-4
tests/test_inputs.py
tests/test_inputs.py
+1
-1
tests/test_logger.py
tests/test_logger.py
+2
-1
tests/test_logits_processor.py
tests/test_logits_processor.py
+10
-4
tests/test_sequence.py
tests/test_sequence.py
+5
-2
tests/test_utils.py
tests/test_utils.py
+10
-23
tests/tracing/test_tracing.py
tests/tracing/test_tracing.py
+68
-0
tests/utils.py
tests/utils.py
+40
-23
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+20
-0
tests/weight_loading/run_model_weight_loading_test.sh
tests/weight_loading/run_model_weight_loading_test.sh
+32
-0
tests/weight_loading/test_weight_loading.py
tests/weight_loading/test_weight_loading.py
+20
-0
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+486
-0
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+83
-1
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+20
-14
vllm/_core_ext.py
vllm/_core_ext.py
+119
-24
vllm/_custom_ops.py
vllm/_custom_ops.py
+88
-15
vllm/adapter_commons/request.py
vllm/adapter_commons/request.py
+0
-2
vllm/assets/audio.py
vllm/assets/audio.py
+26
-0
No files found.
Too many changes to show.
To preserve performance only
448 of 448+
files are displayed.
Plain diff
Email patch
tests/spec_decode/test_multi_step_worker.py
View file @
af7f4372
...
@@ -6,7 +6,8 @@ import pytest
...
@@ -6,7 +6,8 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
Logprob
,
SamplerOutput
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
Logprob
,
SamplerOutput
,
get_all_seq_ids
)
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
...
@@ -690,3 +691,36 @@ def test_use_draft_model_runner_advance_step():
...
@@ -690,3 +691,36 @@ def test_use_draft_model_runner_advance_step():
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
call_args_list
=
worker
.
model_runner
.
_gpu_advance_step
.
call_args_list
call_args_list
=
worker
.
model_runner
.
_gpu_advance_step
.
call_args_list
assert
len
(
call_args_list
)
==
1
assert
len
(
call_args_list
)
==
1
@
torch
.
inference_mode
()
def
test_expand_execute_model_request_sync_with_expand_hidden_states
():
"""
In this test we verify that the logic for expanding the
seq_group_metadata_list remains in sync with the expansion logic of
the HiddenStates in _expand_execute_model_request.
"""
k
=
5
batch_size
=
16
seq_with_bonus_token_in_last_step
=
[
1
,
3
,
8
,
10
,
13
,
15
]
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
execute_model_request
=
ExecuteModelRequest
(
seq_group_metadata_list
,
previous_hidden_states
=
HiddenStates
(
torch
.
arange
(
batch_size
),
seq_group_metadata_list
,
torch
.
arange
(
batch_size
,
2
*
batch_size
)))
expanded_execute_model_request
,
orig_seq_group_ids
=
MultiStepWorker
.
\
_expand_execute_model_request
(
execute_model_request
,
seq_with_bonus_token_in_last_step
)
all_seq_ids
=
torch
.
tensor
(
get_all_seq_ids
(
expanded_execute_model_request
.
seq_group_metadata_list
))
ref_expanded_hidden_states
=
all_seq_ids
+
batch_size
ref_expanded_hidden_states
[
orig_seq_group_ids
]
-=
batch_size
assert
(
ref_expanded_hidden_states
==
expanded_execute_model_request
.
previous_hidden_states
.
hidden_states
).
all
().
item
()
tests/spec_decode/utils.py
View file @
af7f4372
from
array
import
array
from
itertools
import
count
from
itertools
import
count
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -9,7 +10,8 @@ import torch
...
@@ -9,7 +10,8 @@ import torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
...
@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
...
@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
seq_data
=
{
seq_data
=
{
i
:
i
:
SequenceData
(
SequenceData
(
prompt_token_ids
=
prompt_token_ids
[:],
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
prompt_token_ids
[:]),
output_token_ids
=
cont_token_ids
[:],
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
cont_token_ids
[:]),
),
),
},
},
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
...
@@ -161,7 +164,7 @@ def assert_logprobs_dict_allclose(
...
@@ -161,7 +164,7 @@ def assert_logprobs_dict_allclose(
single_step_actual_logprobs
[
token_id
].
logprob
)
single_step_actual_logprobs
[
token_id
].
logprob
)
expected
=
torch
.
tensor
(
expected
=
torch
.
tensor
(
single_step_expected_logprobs
[
token_id
].
logprob
)
single_step_expected_logprobs
[
token_id
].
logprob
)
assert
torch
.
all
close
(
actual
,
expected
)
torch
.
testing
.
assert_
close
(
actual
,
expected
)
def
create_sampler_output_list
(
def
create_sampler_output_list
(
...
...
tests/tensorizer_loader/conftest.py
View file @
af7f4372
import
contextlib
import
contextlib
import
functools
import
functools
import
gc
import
gc
from
typing
import
Callable
,
TypeVar
import
pytest
import
pytest
import
ray
import
ray
import
torch
import
torch
from
typing_extensions
import
ParamSpec
from
vllm.distributed
import
(
destroy_distributed_environment
,
from
vllm.distributed
import
(
destroy_distributed_environment
,
destroy_model_parallel
)
destroy_model_parallel
)
...
@@ -22,12 +24,16 @@ def cleanup():
...
@@ -22,12 +24,16 @@ def cleanup():
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
retry_until_skip
(
n
):
_P
=
ParamSpec
(
"_P"
)
_R
=
TypeVar
(
"_R"
)
def
decorator_retry
(
func
):
def
retry_until_skip
(
n
:
int
):
def
decorator_retry
(
func
:
Callable
[
_P
,
_R
])
->
Callable
[
_P
,
_R
]:
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
def
wrapper_retry
(
*
args
,
**
kwargs
)
:
def
wrapper_retry
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
_R
:
for
i
in
range
(
n
):
for
i
in
range
(
n
):
try
:
try
:
return
func
(
*
args
,
**
kwargs
)
return
func
(
*
args
,
**
kwargs
)
...
@@ -35,7 +41,9 @@ def retry_until_skip(n):
...
@@ -35,7 +41,9 @@ def retry_until_skip(n):
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
if
i
==
n
-
1
:
if
i
==
n
-
1
:
pytest
.
skip
(
"Skipping test after attempts.."
)
pytest
.
skip
(
f
"Skipping test after
{
n
}
attempts."
)
raise
AssertionError
(
"Code should not be reached"
)
return
wrapper_retry
return
wrapper_retry
...
...
tests/test_inputs.py
View file @
af7f4372
...
@@ -2,7 +2,7 @@ from typing import List
...
@@ -2,7 +2,7 @@ from typing import List
import
pytest
import
pytest
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.inputs
.parse
import
parse_and_batch_prompt
STRING_INPUTS
=
[
STRING_INPUTS
=
[
''
,
''
,
...
...
tests/test_logger.py
View file @
af7f4372
...
@@ -49,7 +49,8 @@ def test_default_vllm_root_logger_configuration():
...
@@ -49,7 +49,8 @@ def test_default_vllm_root_logger_configuration():
handler
=
logger
.
handlers
[
0
]
handler
=
logger
.
handlers
[
0
]
assert
isinstance
(
handler
,
logging
.
StreamHandler
)
assert
isinstance
(
handler
,
logging
.
StreamHandler
)
assert
handler
.
stream
==
sys
.
stdout
assert
handler
.
stream
==
sys
.
stdout
assert
handler
.
level
==
logging
.
INFO
# we use DEBUG level for testing by default
# assert handler.level == logging.INFO
formatter
=
handler
.
formatter
formatter
=
handler
.
formatter
assert
formatter
is
not
None
assert
formatter
is
not
None
...
...
tests/test_logits_processor.py
View file @
af7f4372
import
random
import
random
from
array
import
array
from
typing
import
Tuple
from
typing
import
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -8,7 +9,8 @@ import torch
...
@@ -8,7 +9,8 @@ import torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
...
@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
SamplingParams
(
temperature
=
0
,
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
...
@@ -90,5 +94,7 @@ def test_logits_processors(seed: int, device: str):
...
@@ -90,5 +94,7 @@ def test_logits_processors(seed: int, device: str):
assert
torch
.
isinf
(
logits_processor_output
[:,
0
]).
all
()
assert
torch
.
isinf
(
logits_processor_output
[:,
0
]).
all
()
fake_logits
*=
logits_processor
.
scale
fake_logits
*=
logits_processor
.
scale
assert
torch
.
allclose
(
logits_processor_output
[:,
1
],
fake_logits
[:,
1
],
torch
.
testing
.
assert_close
(
logits_processor_output
[:,
1
],
1e-4
)
fake_logits
[:,
1
],
rtol
=
1e-4
,
atol
=
0.0
)
tests/test_sequence.py
View file @
af7f4372
from
array
import
array
import
pytest
import
pytest
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SamplerOutput
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
SamplerOutput
,
SequenceData
,
SequenceOutput
)
SequenceData
,
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
from
.core.utils
import
create_dummy_prompt
...
@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
...
@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
def
test_sequence_data_prefill
():
def
test_sequence_data_prefill
():
seq_data
=
SequenceData
(
prompt_token_ids
=
[
1
,
2
,
3
,
4
])
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
,
4
])
)
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_computed_tokens
()
==
0
assert
seq_data
.
get_num_computed_tokens
()
==
0
# advance by 2
# advance by 2
...
...
tests/test_utils.py
View file @
af7f4372
import
asyncio
import
asyncio
import
os
import
os
import
socket
import
socket
import
sys
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncIterator
,
Awaitable
,
Protocol
,
from
typing
import
AsyncIterator
,
Tuple
Tuple
,
TypeVar
)
import
pytest
import
pytest
...
@@ -12,36 +11,23 @@ from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
...
@@ -12,36 +11,23 @@ from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
from
.utils
import
error_on_warning
from
.utils
import
error_on_warning
if
sys
.
version_info
<
(
3
,
10
):
if
TYPE_CHECKING
:
_AwaitableT
=
TypeVar
(
"_AwaitableT"
,
bound
=
Awaitable
[
Any
])
_AwaitableT_co
=
TypeVar
(
"_AwaitableT_co"
,
bound
=
Awaitable
[
Any
],
covariant
=
True
)
class
_SupportsSynchronousAnext
(
Protocol
[
_AwaitableT_co
]):
def
__anext__
(
self
)
->
_AwaitableT_co
:
...
def
anext
(
i
:
"_SupportsSynchronousAnext[_AwaitableT]"
,
/
)
->
"_AwaitableT"
:
return
i
.
__anext__
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_merge_async_iterators
():
async
def
test_merge_async_iterators
():
async
def
mock_async_iterator
(
idx
:
int
)
->
AsyncIterator
[
str
]
:
async
def
mock_async_iterator
(
idx
:
int
):
try
:
try
:
while
True
:
while
True
:
yield
f
"item from iterator
{
idx
}
"
yield
f
"item from iterator
{
idx
}
"
await
asyncio
.
sleep
(
0.1
)
await
asyncio
.
sleep
(
0.1
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
p
ass
p
rint
(
f
"iterator
{
idx
}
cancelled"
)
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
iterators
=
[
mock_async_iterator
(
i
)
for
i
in
range
(
3
)]
merged_iterator
:
AsyncIterator
[
Tuple
[
int
,
str
]]
=
merge_async_iterators
(
merged_iterator
=
merge_async_iterators
(
*
iterators
,
*
iterators
)
is_cancelled
=
partial
(
asyncio
.
sleep
,
0
,
result
=
False
))
async
def
stream_output
(
generator
:
AsyncIterator
[
Tuple
[
int
,
str
]]):
async
def
stream_output
(
generator
:
AsyncIterator
[
Tuple
[
int
,
str
]]):
async
for
idx
,
output
in
generator
:
async
for
idx
,
output
in
generator
:
...
@@ -55,7 +41,8 @@ async def test_merge_async_iterators():
...
@@ -55,7 +41,8 @@ async def test_merge_async_iterators():
for
iterator
in
iterators
:
for
iterator
in
iterators
:
try
:
try
:
await
asyncio
.
wait_for
(
anext
(
iterator
),
1
)
# Can use anext() in python >= 3.10
await
asyncio
.
wait_for
(
iterator
.
__anext__
(),
1
)
except
StopAsyncIteration
:
except
StopAsyncIteration
:
# All iterators should be cancelled and print this message.
# All iterators should be cancelled and print this message.
print
(
"Iterator was cancelled normally"
)
print
(
"Iterator was cancelled normally"
)
...
...
tests/tracing/test_tracing.py
View file @
af7f4372
...
@@ -114,3 +114,71 @@ def test_traces(trace_service):
...
@@ -114,3 +114,71 @@ def test_traces(trace_service):
SpanAttributes
.
LLM_LATENCY_TIME_TO_FIRST_TOKEN
)
==
ttft
SpanAttributes
.
LLM_LATENCY_TIME_TO_FIRST_TOKEN
)
==
ttft
e2e_time
=
metrics
.
finished_time
-
metrics
.
arrival_time
e2e_time
=
metrics
.
finished_time
-
metrics
.
arrival_time
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_E2E
)
==
e2e_time
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_E2E
)
==
e2e_time
assert
metrics
.
scheduler_time
>
0
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_TIME_IN_SCHEDULER
)
==
metrics
.
scheduler_time
# Model forward and model execute should be none, since detailed traces is
# not enabled.
assert
metrics
.
model_forward_time
is
None
assert
metrics
.
model_execute_time
is
None
def
test_traces_with_detailed_steps
(
trace_service
):
os
.
environ
[
OTEL_EXPORTER_OTLP_TRACES_INSECURE
]
=
"true"
sampling_params
=
SamplingParams
(
temperature
=
0.01
,
top_p
=
0.1
,
max_tokens
=
256
)
model
=
"facebook/opt-125m"
llm
=
LLM
(
model
=
model
,
otlp_traces_endpoint
=
FAKE_TRACE_SERVER_ADDRESS
,
collect_detailed_traces
=
"all"
,
)
prompts
=
[
"This is a short prompt"
]
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
timeout
=
5
if
not
trace_service
.
evt
.
wait
(
timeout
):
raise
TimeoutError
(
f
"The fake trace service didn't receive a trace within "
f
"the
{
timeout
}
seconds timeout"
)
attributes
=
decode_attributes
(
trace_service
.
request
.
resource_spans
[
0
].
scope_spans
[
0
].
spans
[
0
].
attributes
)
assert
attributes
.
get
(
SpanAttributes
.
LLM_RESPONSE_MODEL
)
==
model
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_ID
)
==
outputs
[
0
].
request_id
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_TEMPERATURE
)
==
sampling_params
.
temperature
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_TOP_P
)
==
sampling_params
.
top_p
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_MAX_TOKENS
)
==
sampling_params
.
max_tokens
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_BEST_OF
)
==
sampling_params
.
best_of
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_N
)
==
sampling_params
.
n
assert
attributes
.
get
(
SpanAttributes
.
LLM_USAGE_PROMPT_TOKENS
)
==
len
(
outputs
[
0
].
prompt_token_ids
)
completion_tokens
=
sum
(
len
(
o
.
token_ids
)
for
o
in
outputs
[
0
].
outputs
)
assert
attributes
.
get
(
SpanAttributes
.
LLM_USAGE_COMPLETION_TOKENS
)
==
completion_tokens
metrics
=
outputs
[
0
].
metrics
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_TIME_IN_QUEUE
)
==
metrics
.
time_in_queue
ttft
=
metrics
.
first_token_time
-
metrics
.
arrival_time
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_TIME_TO_FIRST_TOKEN
)
==
ttft
e2e_time
=
metrics
.
finished_time
-
metrics
.
arrival_time
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_E2E
)
==
e2e_time
assert
metrics
.
scheduler_time
>
0
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_TIME_IN_SCHEDULER
)
==
metrics
.
scheduler_time
assert
metrics
.
model_forward_time
>
0
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_TIME_IN_MODEL_FORWARD
)
==
pytest
.
approx
(
metrics
.
model_forward_time
/
1000
)
assert
metrics
.
model_execute_time
>
0
assert
attributes
.
get
(
SpanAttributes
.
LLM_LATENCY_TIME_IN_MODEL_EXECUTE
)
==
metrics
.
model_execute_time
assert
metrics
.
model_forward_time
<
1000
*
metrics
.
model_execute_time
tests/utils.py
View file @
af7f4372
...
@@ -7,19 +7,20 @@ import time
...
@@ -7,19 +7,20 @@ import time
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
openai
import
openai
import
ray
import
requests
import
requests
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
typing_extensions
import
ParamSpec
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_port
,
is_hip
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_port
,
is_hip
if
is_hip
():
if
current_platform
.
is_rocm
():
from
amdsmi
import
(
amdsmi_get_gpu_vram_usage
,
from
amdsmi
import
(
amdsmi_get_gpu_vram_usage
,
amdsmi_get_processor_handles
,
amdsmi_init
,
amdsmi_get_processor_handles
,
amdsmi_init
,
amdsmi_shut_down
)
amdsmi_shut_down
)
...
@@ -31,7 +32,7 @@ if is_hip():
...
@@ -31,7 +32,7 @@ if is_hip():
yield
yield
finally
:
finally
:
amdsmi_shut_down
()
amdsmi_shut_down
()
el
se
:
el
if
current_platform
.
is_cuda
()
:
from
pynvml
import
(
nvmlDeviceGetHandleByIndex
,
nvmlDeviceGetMemoryInfo
,
from
pynvml
import
(
nvmlDeviceGetHandleByIndex
,
nvmlDeviceGetMemoryInfo
,
nvmlInit
,
nvmlShutdown
)
nvmlInit
,
nvmlShutdown
)
...
@@ -42,6 +43,11 @@ else:
...
@@ -42,6 +43,11 @@ else:
yield
yield
finally
:
finally
:
nvmlShutdown
()
nvmlShutdown
()
else
:
@
contextmanager
def
_nvml
():
yield
VLLM_PATH
=
Path
(
__file__
).
parent
.
parent
VLLM_PATH
=
Path
(
__file__
).
parent
.
parent
...
@@ -50,16 +56,14 @@ VLLM_PATH = Path(__file__).parent.parent
...
@@ -50,16 +56,14 @@ VLLM_PATH = Path(__file__).parent.parent
class
RemoteOpenAIServer
:
class
RemoteOpenAIServer
:
DUMMY_API_KEY
=
"token-abc123"
# vLLM's OpenAI server does not need API key
DUMMY_API_KEY
=
"token-abc123"
# vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S
=
120
# wait for server to start for 120 seconds
def
__init__
(
self
,
def
__init__
(
model
:
str
,
self
,
cli_args
:
List
[
str
],
model
:
str
,
*
,
cli_args
:
List
[
str
],
env_dict
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
*
,
auto_port
:
bool
=
True
,
env_dict
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
auto_port
:
bool
=
True
,
)
->
None
:
if
auto_port
:
if
auto_port
:
if
"-p"
in
cli_args
or
"--port"
in
cli_args
:
if
"-p"
in
cli_args
or
"--port"
in
cli_args
:
raise
ValueError
(
"You have manually specified the port"
raise
ValueError
(
"You have manually specified the port"
...
@@ -84,8 +88,9 @@ class RemoteOpenAIServer:
...
@@ -84,8 +88,9 @@ class RemoteOpenAIServer:
env
=
env
,
env
=
env
,
stdout
=
sys
.
stdout
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stderr
)
stderr
=
sys
.
stderr
)
max_wait_seconds
=
max_wait_seconds
or
240
self
.
_wait_for_server
(
url
=
self
.
url_for
(
"health"
),
self
.
_wait_for_server
(
url
=
self
.
url_for
(
"health"
),
timeout
=
self
.
MAX_SERVER_START_WAIT_S
)
timeout
=
max_wait_seconds
)
def
__enter__
(
self
):
def
__enter__
(
self
):
return
self
return
self
...
@@ -139,7 +144,8 @@ def compare_two_settings(model: str,
...
@@ -139,7 +144,8 @@ def compare_two_settings(model: str,
arg1
:
List
[
str
],
arg1
:
List
[
str
],
arg2
:
List
[
str
],
arg2
:
List
[
str
],
env1
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
env1
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
env2
:
Optional
[
Dict
[
str
,
str
]]
=
None
):
env2
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
"""
"""
Launch API server with two different sets of arguments/environments
Launch API server with two different sets of arguments/environments
and compare the results of the API calls.
and compare the results of the API calls.
...
@@ -158,7 +164,10 @@ def compare_two_settings(model: str,
...
@@ -158,7 +164,10 @@ def compare_two_settings(model: str,
token_ids
=
tokenizer
(
prompt
)[
"input_ids"
]
token_ids
=
tokenizer
(
prompt
)[
"input_ids"
]
results
=
[]
results
=
[]
for
args
,
env
in
((
arg1
,
env1
),
(
arg2
,
env2
)):
for
args
,
env
in
((
arg1
,
env1
),
(
arg2
,
env2
)):
with
RemoteOpenAIServer
(
model
,
args
,
env_dict
=
env
)
as
server
:
with
RemoteOpenAIServer
(
model
,
args
,
env_dict
=
env
,
max_wait_seconds
=
max_wait_seconds
)
as
server
:
client
=
server
.
get_client
()
client
=
server
.
get_client
()
# test models list
# test models list
...
@@ -266,8 +275,9 @@ def compare_two_settings(model: str,
...
@@ -266,8 +275,9 @@ def compare_two_settings(model: str,
arg1_results
=
results
[:
n
]
arg1_results
=
results
[:
n
]
arg2_results
=
results
[
n
:]
arg2_results
=
results
[
n
:]
for
arg1_result
,
arg2_result
in
zip
(
arg1_results
,
arg2_results
):
for
arg1_result
,
arg2_result
in
zip
(
arg1_results
,
arg2_results
):
assert
arg1_result
==
arg2_result
,
\
assert
arg1_result
==
arg2_result
,
(
f
"Results for
{
model
=
}
are not the same with
{
arg1
=
}
and
{
arg2
=
}
"
f
"Results for
{
model
=
}
are not the same with
{
arg1
=
}
and
{
arg2
=
}
. "
f
"
{
arg1_result
=
}
!=
{
arg2_result
=
}
"
)
def
init_test_distributed_environment
(
def
init_test_distributed_environment
(
...
@@ -291,6 +301,8 @@ def multi_process_parallel(
...
@@ -291,6 +301,8 @@ def multi_process_parallel(
pp_size
:
int
,
pp_size
:
int
,
test_target
:
Any
,
test_target
:
Any
,
)
->
None
:
)
->
None
:
import
ray
# Using ray helps debugging the error when it failed
# Using ray helps debugging the error when it failed
# as compared to multiprocessing.
# as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
# NOTE: We need to set working_dir for distributed tests,
...
@@ -359,18 +371,23 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
...
@@ -359,18 +371,23 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
time
.
sleep
(
5
)
time
.
sleep
(
5
)
def
fork_new_process_for_each_test
(
f
):
_P
=
ParamSpec
(
"_P"
)
def
fork_new_process_for_each_test
(
f
:
Callable
[
_P
,
None
])
->
Callable
[
_P
,
None
]:
"""Decorator to fork a new process for each test function.
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
"""
@
functools
.
wraps
(
f
)
@
functools
.
wraps
(
f
)
def
wrapper
(
*
args
,
**
kwargs
)
:
def
wrapper
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
None
:
# Make the process the leader of its own process group
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
# to avoid sending SIGTERM to the parent process
os
.
setpgrp
()
os
.
setpgrp
()
from
_pytest.outcomes
import
Skipped
from
_pytest.outcomes
import
Skipped
pid
=
os
.
fork
()
pid
=
os
.
fork
()
print
(
f
"Fork a new process to run a test
{
pid
}
"
)
if
pid
==
0
:
if
pid
==
0
:
try
:
try
:
f
(
*
args
,
**
kwargs
)
f
(
*
args
,
**
kwargs
)
...
@@ -388,11 +405,11 @@ def fork_new_process_for_each_test(f):
...
@@ -388,11 +405,11 @@ def fork_new_process_for_each_test(f):
pgid
=
os
.
getpgid
(
pid
)
pgid
=
os
.
getpgid
(
pid
)
_pid
,
_exitcode
=
os
.
waitpid
(
pid
,
0
)
_pid
,
_exitcode
=
os
.
waitpid
(
pid
,
0
)
# ignore SIGTERM signal itself
# ignore SIGTERM signal itself
old_si
ngla
_handler
=
signal
.
signal
(
signal
.
SIGTERM
,
signal
.
SIG_IGN
)
old_si
gnal
_handler
=
signal
.
signal
(
signal
.
SIGTERM
,
signal
.
SIG_IGN
)
# kill all child processes
# kill all child processes
os
.
killpg
(
pgid
,
signal
.
SIGTERM
)
os
.
killpg
(
pgid
,
signal
.
SIGTERM
)
# restore the signal handler
# restore the signal handler
signal
.
signal
(
signal
.
SIGTERM
,
old_si
ngla
_handler
)
signal
.
signal
(
signal
.
SIGTERM
,
old_si
gnal
_handler
)
assert
_exitcode
==
0
,
(
f
"function
{
f
}
failed when called with"
assert
_exitcode
==
0
,
(
f
"function
{
f
}
failed when called with"
f
" args
{
args
}
and kwargs
{
kwargs
}
"
)
f
" args
{
args
}
and kwargs
{
kwargs
}
"
)
...
...
tests/weight_loading/models.txt
0 → 100644
View file @
af7f4372
gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
\ No newline at end of file
tests/weight_loading/run_model_weight_loading_test.sh
0 → 100644
View file @
af7f4372
#!/bin/bash
SUCCESS
=
0
IFS
=
$'
\n
'
read
-d
''
-r
-a
MODEL_CONFIGS <
"weight_loading/models.txt"
for
MODEL_CONFIG
in
"
${
MODEL_CONFIGS
[@]
}
"
do
LOCAL_SUCCESS
=
0
IFS
=
', '
read
-r
-a
array
<<<
"
$MODEL_CONFIG
"
echo
"=== RUNNING MODEL:
$MODEL_CONFIG
==="
export
QUANTIZATION
=
${
array
[0]
}
export
MODEL_NAME
=
${
array
[1]
}
export
REVISION
=
${
array
[2]
}
pytest
-s
weight_loading/test_weight_loading.py
||
LOCAL_SUCCESS
=
$?
if
[[
$LOCAL_SUCCESS
==
0
]]
;
then
echo
"=== PASSED MODEL:
${
MODEL_CONFIG
}
==="
else
echo
"=== FAILED MODEL:
${
MODEL_CONFIG
}
==="
fi
SUCCESS
=
$((
SUCCESS
+
LOCAL_SUCCESS
))
done
if
[
"
${
SUCCESS
}
"
-eq
"0"
]
;
then
exit
0
else
exit
1
fi
tests/weight_loading/test_weight_loading.py
0 → 100644
View file @
af7f4372
import
os
MAX_MODEL_LEN
=
1024
MODEL_NAME
=
os
.
environ
.
get
(
"MODEL_NAME"
,
"robertgshaw2/zephyr-7b-beta-channelwise-gptq"
)
REVISION
=
os
.
environ
.
get
(
"REVISION"
,
"main"
)
QUANTIZATION
=
os
.
environ
.
get
(
"QUANTIZATION"
,
"gptq_marlin"
)
def
test_weight_loading
(
vllm_runner
):
with
vllm_runner
(
model_name
=
MODEL_NAME
,
revision
=
REVISION
,
dtype
=
"auto"
,
quantization
=
QUANTIZATION
,
max_model_len
=
MAX_MODEL_LEN
,
tensor_parallel_size
=
2
)
as
model
:
output
=
model
.
generate_greedy
(
"Hello world!"
,
max_tokens
=
20
)
print
(
output
)
assert
output
tests/worker/test_encoder_decoder_model_runner.py
0 → 100644
View file @
af7f4372
from
array
import
array
from
typing
import
List
import
pytest
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_cpu
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
# CUDA graph scenarios to test
#
# Currently CUDA graph is not supported
ENFORCE_EAGER
=
[
True
]
BATCH_SIZES
=
[
1
,
4
,
16
,
64
,
256
]
def
_create_model_runner
(
model
:
str
,
*
args
,
**
kwargs
)
->
EncoderDecoderModelRunner
:
engine_args
=
EngineArgs
(
model
,
*
args
,
**
kwargs
)
engine_config
=
engine_args
.
create_engine_config
()
model_runner
=
EncoderDecoderModelRunner
(
model_config
=
engine_config
.
model_config
,
parallel_config
=
engine_config
.
parallel_config
,
scheduler_config
=
engine_config
.
scheduler_config
,
device_config
=
engine_config
.
device_config
,
cache_config
=
engine_config
.
cache_config
,
load_config
=
engine_config
.
load_config
,
lora_config
=
engine_config
.
lora_config
,
prompt_adapter_config
=
engine_config
.
prompt_adapter_config
,
is_driver_worker
=
True
,
)
return
model_runner
@
pytest
.
mark
.
skipif
(
condition
=
is_cpu
(),
reason
=
"CPU backend is currently "
"unsupported for encoder/ "
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
def
test_empty_seq_group
(
enforce_eager
,
):
"""Verify prepare prompt and decode returns empty output
for empty seq group list"""
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
)
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
model_input
=
model_runner
.
_prepare_model_input_tensors
(
seq_group_metadata_list
)
(
input_tokens
,
input_positions
,
encoder_input_tokens
,
encoder_input_positions
,
attn_metadata
,
return_seq_lens
,
)
=
(
model_input
.
input_tokens
,
model_input
.
input_positions
,
model_input
.
encoder_input_tokens
,
model_input
.
encoder_input_positions
,
model_input
.
attn_metadata
,
model_input
.
seq_lens
,
)
assert
input_tokens
is
None
assert
input_positions
is
None
assert
encoder_input_tokens
is
None
assert
encoder_input_positions
is
None
assert
attn_metadata
is
None
assert
return_seq_lens
is
None
@
pytest
.
mark
.
skipif
(
condition
=
is_cpu
(),
reason
=
"CPU backend is currently "
"unsupported for encoder/ "
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
def
test_prepare_prompt
(
batch_size
,
enforce_eager
,
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce prefill-phase model inputs & attention metadata.
Test behavior:
* Instantiate BART base model & enc/dec model runner
* Construct sequence-group metadata for dummy prompts
* Test that encoder attention, decoder self-attention,
and encoder/decoder cross-attention inputs are correct
Arguments:
* batch_size
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
)
seq_lens
:
List
[
int
]
=
[]
encoder_seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
cross_block_table
=
[
2
]
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
encoder_seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
)
assert
seq_group_metadata
.
token_chunk_size
==
seq_data
.
get_len
()
seq_group_metadata_list
.
append
(
seq_group_metadata
)
# Build
# * Decoder model inputs
# * Decoder self-attention KV caching data structures
# * Encoder model inputs
# * Encoder/decoder cross-attention KV caching data structures
model_input
=
model_runner
.
prepare_model_input
(
seq_group_metadata_list
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
return_seq_lens
=
model_input
.
seq_lens
slot_mapping
=
attn_metadata
.
slot_mapping
encoder_input_tokens
=
model_input
.
encoder_input_tokens
encoder_input_positions
=
model_input
.
encoder_input_positions
cross_slot_mapping
=
attn_metadata
.
cross_slot_mapping
assert
return_seq_lens
==
seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
cross_slot_mapping
)
==
len
(
encoder_input_tokens
)
# Verify input metadata is correct for prompts.
# - Decoder attention metadata
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
>
0
assert
attn_metadata
.
num_decode_tokens
==
0
assert
torch
.
equal
(
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
max_prefill_seq_len
==
max
(
seq_lens
)
assert
attn_metadata
.
max_decode_seq_len
==
0
# - Encoder attention metadata
assert
attn_metadata
.
encoder_seq_lens
==
encoder_seq_lens
assert
torch
.
equal
(
attn_metadata
.
encoder_seq_lens_tensor
,
torch
.
tensor
(
encoder_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
max_encoder_seq_len
==
max
(
encoder_seq_lens
)
assert
attn_metadata
.
num_encoder_tokens
==
sum
(
encoder_seq_lens
)
# Test decoder subquery start locs.
start_idx
=
0
start_loc
=
[
start_idx
]
for
seq_len
in
seq_lens
:
start_idx
+=
seq_len
start_loc
.
append
(
start_idx
)
assert
torch
.
equal
(
attn_metadata
.
query_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
),
)
# Test decoder seq start locs & context lengths
assert
torch
.
equal
(
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
),
)
assert
torch
.
equal
(
attn_metadata
.
context_lens_tensor
,
torch
.
zeros
(
attn_metadata
.
context_lens_tensor
.
shape
[
0
],
dtype
=
torch
.
int
,
device
=
device
),
)
# Verify block tables are correct for prompts
# - Decoder self-attention
expected
=
torch
.
tensor
(
[[]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
,
)
assert
torch
.
equal
(
attn_metadata
.
block_tables
,
expected
,
)
# - Encoder/decoder cross-attention
assert
torch
.
equal
(
attn_metadata
.
cross_block_tables
,
expected
,
)
# Cuda graph should not be used for prefill.
assert
attn_metadata
.
use_cuda_graph
is
False
# Verify the lengths of input tokens & positions
# - Decoder
assert
len
(
input_tokens
)
==
sum
(
seq_lens
)
assert
len
(
input_positions
)
==
sum
(
seq_lens
)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
input_tokens
,
input_positions
,
)
# - Encoder
assert
len
(
encoder_input_tokens
)
==
sum
(
encoder_seq_lens
)
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
encoder_input_tokens
,
encoder_input_positions
,
)
# Test that vLLM sampling infrastructure chooses the correct
# sequence positions at which to sample (i.e. the end of
# each sequence) in the prefill phase
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
seq_len
in
seq_lens
:
# Compute the index offset of the final token in each
# prompt (recall that the prompts are concatenated)
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
seq_len
-
1
)
selected_token_start_idx
+=
seq_len
sampling_metadata
=
model_input
.
sampling_metadata
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
,
)
assert
torch
.
equal
(
actual
,
expected
)
@
pytest
.
mark
.
skipif
(
condition
=
is_cpu
(),
reason
=
"CPU backend is currently "
"unsupported for encoder/ "
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
def
test_prepare_decode
(
batch_size
,
enforce_eager
,
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata.
Test behavior:
* Instantiate BART base model & enc/dec model runner
* Construct sequence-group metadata for dummy prompts
* Test that encoder attention, decoder self-attention,
and encoder/decoder cross-attention inputs are correct
Arguments:
* batch_size
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
)
seq_lens
:
List
[
int
]
=
[]
encoder_seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
cross_block_table
=
[
2
]
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
seq_data
=
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
)
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_group_metadata_list
.
append
(
seq_group_metadata
)
# Build
# * Decoder model inputs
# * Decoder self-attention KV caching data structures
# * Encoder model inputs
# * Encoder/decoder cross-attention KV caching data structures
model_input
=
model_runner
.
prepare_model_input
(
seq_group_metadata_list
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
return_seq_lens
=
model_input
.
seq_lens
slot_mapping
=
attn_metadata
.
slot_mapping
encoder_input_tokens
=
model_input
.
encoder_input_tokens
encoder_input_positions
=
model_input
.
encoder_input_positions
cross_slot_mapping
=
attn_metadata
.
cross_slot_mapping
assert
return_seq_lens
==
seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
cross_slot_mapping
)
==
len
(
encoder_input_tokens
)
# Verify input metadata is correct for decode phase.
# - Decoder attention metadata
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
==
0
assert
attn_metadata
.
num_decode_tokens
>
0
assert
torch
.
equal
(
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
max_prefill_seq_len
==
0
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
# - Encoder attention metadata
assert
attn_metadata
.
encoder_seq_lens
==
encoder_seq_lens
assert
torch
.
equal
(
attn_metadata
.
encoder_seq_lens_tensor
,
torch
.
tensor
(
encoder_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
max_encoder_seq_len
==
max
(
encoder_seq_lens
)
assert
attn_metadata
.
num_encoder_tokens
==
sum
(
encoder_seq_lens
)
# Test decoder subquery start locs.
start_idx
=
0
start_loc
=
[
start_idx
]
for
seq_len
in
seq_lens
:
start_idx
+=
1
start_loc
.
append
(
start_idx
)
assert
torch
.
equal
(
attn_metadata
.
query_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
),
)
# Test decoder seq start locs. Note that for normal prefill it is
# equivalent to query_start_loc.
start_idx
=
0
seq_start_loc
=
[
start_idx
]
for
seq_len
in
seq_lens
:
start_idx
+=
seq_len
seq_start_loc
.
append
(
start_idx
)
# Test seq_start_loc and context lengths
assert
torch
.
equal
(
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
seq_start_loc
,
dtype
=
torch
.
int32
,
device
=
device
),
)
assert
torch
.
equal
(
attn_metadata
.
context_lens_tensor
,
torch
.
tensor
([
seq_len
-
1
for
seq_len
in
seq_lens
],
dtype
=
torch
.
int
,
device
=
device
))
# Verify block tables are correct for prompts
# - Decoder self-attention
expected
=
torch
.
tensor
(
[
block_tables
[
0
]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
assert
torch
.
equal
(
attn_metadata
.
block_tables
,
expected
,
)
# - Encoder/decoder cross-attention
expected
=
torch
.
tensor
(
[
cross_block_table
for
_
in
range
(
len
(
seq_group_metadata_list
))],
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
assert
torch
.
equal
(
attn_metadata
.
cross_block_tables
,
expected
,
)
# Cuda graph should is currently not supported for encoder/decoer.
assert
attn_metadata
.
use_cuda_graph
is
False
# Verify the lengths of input tokens & positions
# - Decoder
assert
len
(
input_tokens
)
==
len
(
seq_lens
)
assert
len
(
input_positions
)
==
len
(
seq_lens
)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
input_tokens
,
input_positions
,
)
# - Encoder
assert
len
(
encoder_input_tokens
)
==
0
assert
len
(
encoder_input_tokens
)
==
0
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
encoder_input_tokens
,
encoder_input_positions
,
)
# Test that vLLM sampling infrastructure chooses the correct
# sequence positions at which to sample (i.e. the end of
# each sequence) in the decode phase
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
seq_len
in
seq_lens
:
# Compute the index offset of the final token in each
# sequence's decoded outputs; since a single token is
# decoded per iteration per sequence, then the length
# of the decoded tokens for a given sequence is 1 and
# the final index offset into a given sequence's
# generated tokens is 0 (i.e. the expected sampling index
# for a given sequence is just `selected_token_start_idx`)
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
selected_token_start_idx
+=
1
sampling_metadata
=
model_input
.
sampling_metadata
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
,
)
assert
torch
.
equal
(
actual
,
expected
)
tests/worker/test_model_input.py
View file @
af7f4372
...
@@ -5,11 +5,13 @@ import torch
...
@@ -5,11 +5,13 @@ import torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.worker.embedding_model_runner
import
(
from
vllm.worker.embedding_model_runner
import
(
ModelInputForGPUWithPoolingMetadata
)
ModelInputForGPUWithPoolingMetadata
)
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
vllm.worker.multi_step_model_runner
import
StatefulModelInput
class
MockAttentionBackend
(
AttentionBackend
):
class
MockAttentionBackend
(
AttentionBackend
):
...
@@ -28,7 +30,11 @@ class MockAttentionBackend(AttentionBackend):
...
@@ -28,7 +30,11 @@ class MockAttentionBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
get_builder_cls
()
->
Type
[
"AttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"AttentionMetadataBuilder"
]:
raise
AttentionMetadataBuilder
return
AttentionMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -154,3 +160,79 @@ def test_embedding_model_runner_input():
...
@@ -154,3 +160,79 @@ def test_embedding_model_runner_input():
None
)
==
getattr
(
attn_metadata
,
field
.
name
,
None
)
None
)
==
getattr
(
attn_metadata
,
field
.
name
,
None
)
# Pooling metadata is not broadcast.
# Pooling metadata is not broadcast.
assert
received_model_input
.
pooling_metadata
is
None
assert
received_model_input
.
pooling_metadata
is
None
def
test_multi_step_model_runner_input
():
sampling_metadata
=
SamplingMetadata
(
[
"seq_group"
],
"selected_token_indices"
,
"categorized_sample_indices"
,
"num_prompts"
,
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
1
,
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
)
frozen_model_input
=
ModelInputForGPUWithSamplingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
input_positions
=
torch
.
ones
(
10
),
sampling_metadata
=
sampling_metadata
,
attn_metadata
=
attn_metadata
)
model_input
=
StatefulModelInput
(
frozen_model_input
=
frozen_model_input
,
is_last_step
=
True
,
is_first_multi_step
=
False
,
current_step
=
4
,
last_sampled_token_ids
=
torch
.
ones
((
10
,
1
)),
is_multi_step
=
True
,
num_queries
=
8
,
num_seqs
=
5
,
cached_outputs
=
[],
)
assert
isinstance
(
model_input
,
StatefulModelInput
)
# Test round trip serialization.
tensor_dict
=
model_input
.
as_broadcastable_tensor_dict
()
attn_backend
=
MockAttentionBackend
()
received_model_input
=
(
StatefulModelInput
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
attn_backend
))
receieved_frozen_input
=
received_model_input
.
frozen_model_input
# Check that received copy has correct values.
assert
isinstance
(
received_model_input
,
StatefulModelInput
)
assert
receieved_frozen_input
.
input_tokens
is
not
None
assert
(
receieved_frozen_input
.
input_tokens
==
frozen_model_input
.
input_tokens
).
all
()
assert
receieved_frozen_input
.
input_positions
is
not
None
assert
(
receieved_frozen_input
.
input_positions
==
frozen_model_input
.
input_positions
).
all
()
assert
receieved_frozen_input
.
multi_modal_kwargs
is
None
assert
(
frozen_model_input
.
multi_modal_kwargs
==
frozen_model_input
.
multi_modal_kwargs
)
assert
receieved_frozen_input
.
lora_requests
is
None
assert
(
receieved_frozen_input
.
lora_requests
==
frozen_model_input
.
lora_requests
)
assert
receieved_frozen_input
.
lora_mapping
is
None
assert
(
receieved_frozen_input
.
lora_mapping
==
frozen_model_input
.
lora_mapping
)
for
field
in
dataclasses
.
fields
(
AttentionMetadata
):
assert
getattr
(
receieved_frozen_input
.
attn_metadata
,
field
.
name
,
None
)
==
getattr
(
attn_metadata
,
field
.
name
,
None
)
# For sampling metadata, only selected_token_indices is copied.
assert
(
receieved_frozen_input
.
sampling_metadata
.
selected_token_indices
==
sampling_metadata
.
selected_token_indices
)
assert
receieved_frozen_input
.
sampling_metadata
.
seq_groups
is
None
# check non frozen fields
assert
received_model_input
.
is_last_step
==
model_input
.
is_last_step
assert
(
received_model_input
.
is_first_multi_step
==
model_input
.
is_first_multi_step
)
assert
received_model_input
.
current_step
==
model_input
.
current_step
assert
(
received_model_input
.
last_sampled_token_ids
==
model_input
.
last_sampled_token_ids
).
all
()
assert
received_model_input
.
is_multi_step
==
model_input
.
is_multi_step
tests/worker/test_model_runner.py
View file @
af7f4372
from
array
import
array
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
...
@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
...
@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
@@ -24,6 +26,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
...
@@ -24,6 +26,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
load_config
=
engine_config
.
load_config
,
load_config
=
engine_config
.
load_config
,
lora_config
=
engine_config
.
lora_config
,
lora_config
=
engine_config
.
lora_config
,
prompt_adapter_config
=
engine_config
.
prompt_adapter_config
,
prompt_adapter_config
=
engine_config
.
prompt_adapter_config
,
observability_config
=
engine_config
.
observability_config
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
return
model_runner
return
model_runner
...
@@ -45,7 +48,8 @@ def test_prepare_prompt(batch_size):
...
@@ -45,7 +48,8 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -76,7 +80,7 @@ def test_prepare_prompt(batch_size):
...
@@ -76,7 +80,7 @@ def test_prepare_prompt(batch_size):
device
=
model_runner
.
device
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
>
0
assert
attn_metadata
.
num_prefills
>
0
assert
attn_metadata
.
num_decode_tokens
==
0
assert
attn_metadata
.
num_decode_tokens
==
0
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
seq_lens
==
seq_lens
...
@@ -89,7 +93,7 @@ def test_prepare_prompt(batch_size):
...
@@ -89,7 +93,7 @@ def test_prepare_prompt(batch_size):
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
start_idx
+=
seq_len
start_idx
+=
seq_len
start_loc
.
append
(
start_idx
)
start_loc
.
append
(
start_idx
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
query_start_loc
,
attn_metadata
.
query_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
...
@@ -101,10 +105,10 @@ def test_prepare_prompt(batch_size):
...
@@ -101,10 +105,10 @@ def test_prepare_prompt(batch_size):
start_idx
+=
seq_len
start_idx
+=
seq_len
seq_start_loc
.
append
(
start_idx
)
seq_start_loc
.
append
(
start_idx
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
seq_start_loc
,
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
context_lens_tensor
,
attn_metadata
.
context_lens_tensor
,
torch
.
zeros
(
attn_metadata
.
context_lens_tensor
.
shape
[
0
],
torch
.
zeros
(
attn_metadata
.
context_lens_tensor
.
shape
[
0
],
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
...
@@ -113,7 +117,7 @@ def test_prepare_prompt(batch_size):
...
@@ -113,7 +117,7 @@ def test_prepare_prompt(batch_size):
expected
=
torch
.
tensor
([[]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
expected
=
torch
.
tensor
([[]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
device
=
model_runner
.
device
)
assert
torch
.
all
close
(
attn_metadata
.
block_tables
,
expected
)
torch
.
testing
.
assert_
close
(
attn_metadata
.
block_tables
,
expected
)
# Cuda graph should not be used for prerill.
# Cuda graph should not be used for prerill.
assert
attn_metadata
.
use_cuda_graph
is
False
assert
attn_metadata
.
use_cuda_graph
is
False
...
@@ -162,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -162,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_lens
.
append
(
context_len
)
context_lens
.
append
(
context_len
)
seq_data
=
SequenceData
(
list
(
range
(
context_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
)))
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
# Append one token ID since prefill is finished.
# Append one token ID since prefill is finished.
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
...
@@ -200,7 +205,7 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -200,7 +205,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# decode has only 1 token for query.
# decode has only 1 token for query.
start_idx
+=
1
start_idx
+=
1
start_loc
.
append
(
start_idx
)
start_loc
.
append
(
start_idx
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
query_start_loc
,
attn_metadata
.
query_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
...
@@ -209,15 +214,15 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -209,15 +214,15 @@ def test_prepare_decode_cuda_graph(batch_size):
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
start_idx
+=
seq_len
start_idx
+=
seq_len
seq_start_loc
.
append
(
start_idx
)
seq_start_loc
.
append
(
start_idx
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
seq_start_loc
,
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
seq_start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
seq_start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
context_lens_tensor
,
attn_metadata
.
context_lens_tensor
,
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
device
))
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
device
))
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
seq_lens_tensor
[:
len
(
seq_lens
)],
attn_metadata
.
seq_lens_tensor
[:
len
(
seq_lens
)],
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
))
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
))
...
@@ -323,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -323,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
list
(
range
(
seq_len
)))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -339,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -339,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
list
(
range
(
context_len
))
prompt_toks
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
...
...
vllm/_core_ext.py
View file @
af7f4372
import
importlib.util
import
importlib.util
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -31,14 +31,14 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -31,14 +31,14 @@ if TYPE_CHECKING or not core_C_available:
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ScalarType
:
class
ScalarType
:
"""
"""
ScalarType can represent a wide range of floating point and integer
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.:
capable of representing types with a bias, i.e.:
`stored_value = value + bias`,
`stored_value = value + bias`,
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
of 8). The implementation for this class can be found in
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
with that file.
with that file.
"""
"""
...
@@ -51,15 +51,15 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -51,15 +51,15 @@ if TYPE_CHECKING or not core_C_available:
mantissa
:
int
mantissa
:
int
"""
"""
Number of bits in the mantissa if this is a floating point type,
Number of bits in the mantissa if this is a floating point type,
or the number bits representing an integer excluding the sign bit if
or the number bits representing an integer excluding the sign bit if
this an integer type.
this an integer type.
"""
"""
bias
:
int
bias
:
int
"""
"""
bias used to encode the values in this scalar type
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0) for example if we store the
(value = stored_value - bias, default 0) for example if we store the
type as an unsigned integer with a bias of 128 then the value 0 will be
type as an unsigned integer with a bias of 128 then the value 0 will be
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
"""
"""
...
@@ -73,7 +73,7 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -73,7 +73,7 @@ if TYPE_CHECKING or not core_C_available:
nan_repr
:
int
=
NanRepr
.
IEEE_754
.
value
nan_repr
:
int
=
NanRepr
.
IEEE_754
.
value
"""
"""
How NaNs are represent in this scalar type, returns NanRepr value.
How NaNs are represent in this scalar type, returns NanRepr value.
(not applicable for integer types)
(not applicable for integer types)
"""
"""
...
@@ -83,14 +83,14 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -83,14 +83,14 @@ if TYPE_CHECKING or not core_C_available:
def
min
(
self
)
->
Union
[
int
,
float
]:
def
min
(
self
)
->
Union
[
int
,
float
]:
"""
"""
Min representable value for this scalar type.
Min representable value for this scalar type.
(accounting for bias if there is one)
(accounting for bias if there is one)
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
max
(
self
)
->
Union
[
int
,
float
]:
def
max
(
self
)
->
Union
[
int
,
float
]:
"""
"""
Max representable value for this scalar type.
Max representable value for this scalar type.
(accounting for bias if there is one)
(accounting for bias if there is one)
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -103,28 +103,28 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -103,28 +103,28 @@ if TYPE_CHECKING or not core_C_available:
"""
"""
...
...
def
is_floating_point
(
self
):
def
is_floating_point
(
self
)
->
bool
:
"If the type is a floating point type"
"If the type is a floating point type"
return
self
.
exponent
!=
0
return
self
.
exponent
!=
0
def
is_integer
(
self
):
def
is_integer
(
self
)
->
bool
:
"If the type is an integer type"
"If the type is an integer type"
return
self
.
exponent
==
0
return
self
.
exponent
==
0
def
has_bias
(
self
):
def
has_bias
(
self
)
->
bool
:
"If the type has a non-zero bias"
"If the type has a non-zero bias"
return
self
.
bias
!=
0
return
self
.
bias
!=
0
def
has_infs
(
self
):
def
has_infs
(
self
)
->
bool
:
"If the type is floating point and supports infinity"
"If the type is floating point and supports infinity"
return
not
self
.
_finite_values_only
return
not
self
.
_finite_values_only
def
has_nans
(
self
):
def
has_nans
(
self
)
->
bool
:
return
self
.
nan_repr
!=
NanRepr
.
NONE
.
value
return
self
.
nan_repr
!=
NanRepr
.
NONE
.
value
def
is_ieee_754
(
self
)
->
bool
:
def
is_ieee_754
(
self
)
->
bool
:
"""
"""
If the type is a floating point type that follows IEEE 754
If the type is a floating point type that follows IEEE 754
conventions
conventions
"""
"""
return
self
.
nan_repr
==
NanRepr
.
IEEE_754
.
value
and
\
return
self
.
nan_repr
==
NanRepr
.
IEEE_754
.
value
and
\
...
@@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
raise
NotImplementedError
raise
NotImplementedError
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def
__len__
(
self
)
->
int
:
raise
TypeError
#
#
# Convenience Constructors
# Convenience Constructors
#
#
...
@@ -153,16 +158,16 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -153,16 +158,16 @@ if TYPE_CHECKING or not core_C_available:
@
classmethod
@
classmethod
def
float_IEEE754
(
cls
,
exponent
:
int
,
mantissa
:
int
)
->
'ScalarType'
:
def
float_IEEE754
(
cls
,
exponent
:
int
,
mantissa
:
int
)
->
'ScalarType'
:
"""
"""
Create a standard floating point type
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
(i.e. follows IEEE 754 conventions).
"""
"""
return
cls
(
exponent
,
mantissa
,
0
,
True
)
return
cls
(
exponent
,
mantissa
,
0
,
True
)
@
classmethod
@
classmethod
def
float_
(
cls
,
exponent
:
int
,
mantissa
:
int
,
finite_values_only
:
bool
,
def
float_
(
cls
,
exponent
:
int
,
mantissa
:
int
,
finite_values_only
:
bool
,
nan_repr
:
int
):
nan_repr
:
int
)
->
'ScalarType'
:
"""
"""
Create a non-standard floating point type
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
(i.e. does not follow IEEE 754 conventions).
"""
"""
return
cls
(
exponent
,
mantissa
,
0
,
True
,
finite_values_only
,
return
cls
(
exponent
,
mantissa
,
0
,
True
,
finite_values_only
,
...
@@ -175,3 +180,93 @@ elif core_C_available:
...
@@ -175,3 +180,93 @@ elif core_C_available:
logger
.
warning
(
"Failed to import from vllm._core_C with %r"
,
e
)
logger
.
warning
(
"Failed to import from vllm._core_C with %r"
,
e
)
ScalarType
=
torch
.
classes
.
_core_C
.
ScalarType
ScalarType
=
torch
.
classes
.
_core_C
.
ScalarType
# Needed for dynamo support of ScalarType.
@
torch
.
_library
.
register_fake_class
(
"_core_C::ScalarType"
)
class
FakeScalarType
:
def
__init__
(
self
,
scalar_type
):
self
.
ScalarType
=
scalar_type
def
bias_getter
(
self
)
->
int
:
return
self
.
ScalarType
.
bias
def
exponent_getter
(
self
)
->
int
:
return
self
.
ScalarType
.
exponent
def
mantissa_getter
(
self
)
->
int
:
return
self
.
ScalarType
.
mantissa
def
signed_getter
(
self
)
->
bool
:
return
self
.
ScalarType
.
signed
def
size_bits_getter
(
self
)
->
int
:
return
self
.
ScalarType
.
size_bits
@
property
def
size_bits
(
self
)
->
int
:
return
self
.
ScalarType
.
size_bits
def
min
(
self
)
->
Union
[
int
,
float
]:
return
self
.
ScalarType
.
min
()
def
max
(
self
)
->
Union
[
int
,
float
]:
return
self
.
ScalarType
.
max
()
def
is_signed
(
self
)
->
bool
:
return
self
.
ScalarType
.
is_signed
()
def
is_floating_point
(
self
)
->
bool
:
return
self
.
ScalarType
.
is_floating_point
()
def
is_integer
(
self
)
->
bool
:
return
self
.
ScalarType
.
is_integer
()
def
has_bias
(
self
)
->
bool
:
return
self
.
ScalarType
.
has_bias
()
def
has_infs
(
self
)
->
bool
:
return
self
.
ScalarType
.
has_infs
()
def
has_nans
(
self
)
->
bool
:
return
self
.
ScalarType
.
has_nans
()
def
is_ieee_754
(
self
)
->
bool
:
return
self
.
ScalarType
.
is_ieee_754
()
def
__str__
(
self
)
->
str
:
return
self
.
ScalarType
.
__str__
()
def
__repr__
(
self
)
->
str
:
return
self
.
ScalarType
.
__repr__
()
def
__len__
(
self
)
->
int
:
return
self
.
ScalarType
.
__len__
()
def
__obj_flatten__
(
self
)
->
Tuple
[
Tuple
[
str
,
Any
],
...]:
return
torch
.
classes
.
_core_C
.
ScalarType
.
__obj_flatten__
(
self
.
ScalarType
)
@
classmethod
def
__obj_unflatten__
(
cls
,
flat_type
:
Tuple
[
Tuple
[
str
,
Any
],
...])
->
'ScalarType'
:
return
cls
(
torch
.
classes
.
_core_C
.
ScalarType
.
__obj_unflatten__
(
flat_type
))
@
classmethod
def
int_
(
cls
,
size_bits
:
int
,
bias
:
Optional
[
int
])
->
'ScalarType'
:
return
ScalarType
.
int_
(
size_bits
,
bias
)
@
classmethod
def
uint
(
cls
,
size_bits
:
int
,
bias
:
Optional
[
int
])
->
'ScalarType'
:
return
ScalarType
.
uint
(
size_bits
,
bias
)
@
classmethod
def
float_IEEE754
(
cls
,
exponent
:
int
,
mantissa
:
int
)
->
'ScalarType'
:
return
ScalarType
.
float_IEEE754
(
exponent
,
mantissa
)
@
classmethod
def
float_
(
cls
,
exponent
:
int
,
mantissa
:
int
,
finite_values_only
:
bool
,
nan_repr
:
int
)
->
'ScalarType'
:
return
ScalarType
.
float_
(
exponent
,
mantissa
,
finite_values_only
,
nan_repr
)
vllm/_custom_ops.py
View file @
af7f4372
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
from
vllm._core_ext
import
ScalarType
from
vllm._core_ext
import
ScalarType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
try
:
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_ops
...
@@ -14,19 +15,14 @@ except Exception:
...
@@ -14,19 +15,14 @@ except Exception:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
try
:
if
not
current_platform
.
is_tpu
():
import
vllm._C
try
:
except
ImportError
as
e
:
import
vllm._C
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
with
contextlib
.
suppress
(
ImportError
):
with
contextlib
.
suppress
(
ImportError
):
# ruff: noqa: F401
import
vllm._moe_C
# noqa: F401
import
vllm._moe_C
def
is_custom_op_supported
(
op_name
:
str
)
->
bool
:
op
,
overloads
=
torch
.
_C
.
_jit_get_operation
(
op_name
)
return
op
is
not
None
def
hint_on_error
(
fn
):
def
hint_on_error
(
fn
):
...
@@ -375,6 +371,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
...
@@ -375,6 +371,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
shape
[
0
]
==
b
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
m
=
a
.
shape
[
0
]
m
=
a
.
shape
[
0
]
n
=
b
.
shape
[
1
]
n
=
b
.
shape
[
1
]
...
@@ -385,17 +383,39 @@ def cutlass_scaled_mm(a: torch.Tensor,
...
@@ -385,17 +383,39 @@ def cutlass_scaled_mm(a: torch.Tensor,
return
out
return
out
def
cutlass_scaled_mm_azp
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
azp_adj
:
torch
.
Tensor
,
azp
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
numel
(
)
==
b
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
m
=
a
.
shape
[
0
]
n
=
b
.
shape
[
1
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
(
out
,
a
,
b
,
scale_a
,
scale_b
,
azp_adj
,
azp
,
bias
)
return
out
# aqlm
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
,
codebook_partition_sizes
:
List
[
int
]
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
return
torch
.
ops
.
_C
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
codebook_partition_sizes
,
bias
)
codebook_partition_sizes
,
bias
)
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
)
->
torch
.
Tensor
:
codebook_partition_sizes
:
List
[
int
]
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
aqlm_dequant
(
codes
,
codebooks
,
return
torch
.
ops
.
_C
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
codebook_partition_sizes
)
...
@@ -443,6 +463,32 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -443,6 +463,32 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
num_bits
,
size_m
,
size_n
,
size_k
)
num_bits
,
size_m
,
size_n
,
size_k
)
# machete
def
machete_supported_schedules
(
b_type
:
ScalarType
)
->
List
[
str
]:
return
torch
.
ops
.
_C
.
machete_supported_schedules
(
b_type
)
def
machete_gemm
(
a
:
torch
.
Tensor
,
b_q
:
torch
.
Tensor
,
# Should be the tensor returned by machete_prepack_B
b_type
:
ScalarType
,
b_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
b_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
b_group_size
:
Optional
[
int
]
=
None
,
c
:
Optional
[
torch
.
Tensor
]
=
None
,
alpha
:
Optional
[
float
]
=
None
,
beta
:
Optional
[
float
]
=
None
,
schedule
:
Optional
[
str
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
machete_gemm
(
a
,
b_q
,
b_type
,
b_scales
,
b_zeros
,
b_group_size
,
c
,
alpha
,
beta
,
schedule
)
def
machete_prepack_B
(
b_q_weight
:
torch
.
Tensor
,
b_type
:
ScalarType
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
b_type
)
# fp8
# fp8
# def scaled_fp8_quant(
# def scaled_fp8_quant(
# input: torch.Tensor,
# input: torch.Tensor,
...
@@ -477,9 +523,12 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -477,9 +523,12 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# # This code assumes batch_dim and num_tokens are flattened
# # This code assumes batch_dim and num_tokens are flattened
# assert (input.ndim == 2)
# assert (input.ndim == 2)
# shape: Union[Tuple[int, int], torch.Size] = input.shape
# shape: Union[Tuple[int, int], torch.Size] = input.shape
# # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
# else torch.float8_e4m3fn
# if num_token_padding:
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# output = torch.empty(shape, device=input.device, dtype=
torch.float8_e4m3fn
)
# output = torch.empty(shape, device=input.device, dtype=
out_dtype
)
# if scale is None:
# if scale is None:
# if use_per_token_if_dynamic:
# if use_per_token_if_dynamic:
...
@@ -538,6 +587,30 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -538,6 +587,30 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
workspace
,
size_m
,
size_n
,
size_k
)
workspace
,
size_m
,
size_n
,
size_k
)
# gguf
def
ggml_dequantize
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
int
,
n
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
ggml_dequantize
(
W
,
quant_type
,
m
,
n
)
def
ggml_mul_mat_vec_a8
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
int
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
ggml_mul_mat_vec_a8
(
W
,
X
,
quant_type
,
row
)
def
ggml_mul_mat_a8
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
int
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
ggml_mul_mat_a8
(
W
,
X
,
quant_type
,
row
)
# moe
# moe
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
...
@@ -674,7 +747,7 @@ for k, v in names_and_values.items():
...
@@ -674,7 +747,7 @@ for k, v in names_and_values.items():
if
isinstance
(
v
,
fn_type
)
\
if
isinstance
(
v
,
fn_type
)
\
and
v
.
__code__
.
co_filename
==
__file__
\
and
v
.
__code__
.
co_filename
==
__file__
\
and
any
(
arg
is
torch
.
Tensor
or
arg
==
"torch.Tensor"
and
any
(
arg
is
torch
.
Tensor
or
arg
==
"torch.Tensor"
for
arg
in
v
.
__annotations__
.
values
()):
for
arg
in
v
.
__annotations__
.
values
()):
names_and_values_to_update
[
k
]
=
hint_on_error
(
v
)
names_and_values_to_update
[
k
]
=
hint_on_error
(
v
)
names_and_values
.
update
(
names_and_values_to_update
)
names_and_values
.
update
(
names_and_values_to_update
)
...
...
vllm/adapter_commons/request.py
View file @
af7f4372
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
@
dataclass
class
AdapterRequest
(
ABC
):
class
AdapterRequest
(
ABC
):
"""
"""
Base class for adapter requests.
Base class for adapter requests.
...
...
vllm/assets/audio.py
0 → 100644
View file @
af7f4372
from
dataclasses
import
dataclass
from
typing
import
Literal
,
Tuple
from
urllib.parse
import
urljoin
import
librosa
import
numpy
as
np
from
vllm.assets.base
import
get_vllm_public_assets
,
vLLM_S3_BUCKET_URL
ASSET_DIR
=
"multimodal_asset"
@
dataclass
(
frozen
=
True
)
class
AudioAsset
:
name
:
Literal
[
"winning_call"
,
"mary_had_lamb"
]
@
property
def
audio_and_sample_rate
(
self
)
->
Tuple
[
np
.
ndarray
,
int
]:
audio_path
=
get_vllm_public_assets
(
filename
=
f
"
{
self
.
name
}
.ogg"
,
s3_prefix
=
ASSET_DIR
)
return
librosa
.
load
(
audio_path
,
sr
=
None
)
@
property
def
url
(
self
)
->
str
:
return
urljoin
(
vLLM_S3_BUCKET_URL
,
f
"
{
ASSET_DIR
}
/
{
self
.
name
}
.ogg"
)
Prev
1
…
9
10
11
12
13
14
15
16
17
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment