Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f48954a4
Commit
f48954a4
authored
Jun 12, 2024
by
zhuwenwen
Browse files
merge v0.5.0
parents
1dba29d3
8f89d720
Changes
253
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
553 additions
and
332 deletions
+553
-332
tests/samplers/test_beam_search.py
tests/samplers/test_beam_search.py
+6
-14
tests/samplers/test_ignore_eos.py
tests/samplers/test_ignore_eos.py
+13
-11
tests/samplers/test_logits_processor.py
tests/samplers/test_logits_processor.py
+41
-41
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+37
-29
tests/samplers/test_ranks.py
tests/samplers/test_ranks.py
+21
-17
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+2
-3
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+6
-1
tests/spec_decode/test_dynamic_spec_decode.py
tests/spec_decode/test_dynamic_spec_decode.py
+2
-2
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+12
-9
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+12
-9
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+55
-74
tests/test_config.py
tests/test_config.py
+6
-1
tests/test_regression.py
tests/test_regression.py
+21
-0
tests/test_sharded_state_loader.py
tests/test_sharded_state_loader.py
+82
-45
tests/test_utils.py
tests/test_utils.py
+15
-1
tests/tokenization/test_image_processor.py
tests/tokenization/test_image_processor.py
+20
-0
tests/utils.py
tests/utils.py
+2
-1
vllm/__init__.py
vllm/__init__.py
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+187
-62
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+12
-11
No files found.
tests/samplers/test_beam_search.py
View file @
f48954a4
...
@@ -2,10 +2,8 @@
...
@@ -2,10 +2,8 @@
Run `pytest tests/samplers/test_beam_search.py`.
Run `pytest tests/samplers/test_beam_search.py`.
"""
"""
import
gc
import
pytest
import
pytest
import
torch
# FIXME(zhuohan): The test can not pass if we:
# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
# 1. Increase max_tokens to 256.
...
@@ -30,19 +28,13 @@ def test_beam_search_single_input(
...
@@ -30,19 +28,13 @@ def test_beam_search_single_input(
beam_width
:
int
,
beam_width
:
int
,
)
->
None
:
)
->
None
:
example_prompts
=
example_prompts
[:
1
]
example_prompts
=
example_prompts
[:
1
]
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_beam_search
(
example_prompts
,
beam_width
,
hf_outputs
=
hf_model
.
generate_beam_search
(
example_prompts
,
beam_width
,
max_tokens
)
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
vllm_outputs
=
vllm_model
.
generate_beam_search
(
example_prompts
,
beam_width
,
max_tokens
)
max_tokens
)
del
vllm_model
# NOTE(woosuk): For some reason, the following GC is required to avoid
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
# GPU OOM errors in the following tests using `vllm_runner`.
vllm_outputs
=
vllm_model
.
generate_beam_search
(
example_prompts
,
gc
.
collect
()
beam_width
,
max_tokens
)
torch
.
cuda
.
empty_cache
()
for
i
in
range
(
len
(
example_prompts
)):
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
_
=
hf_outputs
[
i
]
hf_output_ids
,
_
=
hf_outputs
[
i
]
...
...
tests/samplers/test_ignore_eos.py
View file @
f48954a4
...
@@ -7,25 +7,27 @@ import pytest
...
@@ -7,25 +7,27 @@ import pytest
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
MODELS
=
[
"facebook/opt-125m"
]
# We also test with llama because it has generation_config to specify EOS
# (past regression).
MODELS
=
[
"facebook/opt-125m"
,
"meta-llama/Llama-2-7b-hf"
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
512
])
def
test_
beam_search_single_input
(
def
test_
ignore_eos
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
)
->
None
:
)
->
None
:
example_prompts
=
"1 + 1 is"
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
ignore_eos
=
True
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
for
prompt
in
example_prompts
:
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
ignore_eos
=
True
)
ignore_eos_output
=
vllm_model
.
model
.
generate
(
ignore_eos_output
=
vllm_model
.
model
.
generate
(
prompt
,
sampling_params
=
sampling_params
)
example_prompts
,
sampling_params
=
sampling_params
)
output_length
=
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
)
print
(
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
))
assert
output_length
==
max_tokens
assert
max_tokens
-
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
)
<
10
assert
max_tokens
-
len
(
ignore_eos_output
[
0
].
outputs
[
0
].
token_ids
)
>=
0
tests/samplers/test_logits_processor.py
View file @
f48954a4
...
@@ -14,46 +14,46 @@ def test_logits_processor_force_generate(
...
@@ -14,46 +14,46 @@ def test_logits_processor_force_generate(
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
)
->
None
:
)
->
None
:
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
tokenizer
=
vllm_model
.
model
.
get_tokenizer
()
tokenizer
=
vllm_model
.
model
.
get_tokenizer
()
repeat_times
=
2
repeat_times
=
2
enforced_answers
=
" vLLM"
enforced_answers
=
" vLLM"
vllm_token_ids
=
tokenizer
.
encode
(
enforced_answers
,
vllm_token_ids
=
tokenizer
.
encode
(
enforced_answers
,
add_special_tokens
=
False
)
add_special_tokens
=
False
)
max_tokens
=
len
(
vllm_token_ids
)
*
repeat_times
max_tokens
=
len
(
vllm_token_ids
)
*
repeat_times
def
pick_vllm
(
token_ids
,
logits
):
def
pick_vllm
(
token_ids
,
logits
):
token_id
=
vllm_token_ids
[
len
(
token_ids
)
%
len
(
vllm_token_ids
)]
token_id
=
vllm_token_ids
[
len
(
token_ids
)
%
len
(
vllm_token_ids
)]
logits
[
token_id
]
=
torch
.
finfo
(
logits
.
dtype
).
max
logits
[
token_id
]
=
torch
.
finfo
(
logits
.
dtype
).
max
return
logits
return
logits
params_with_logprobs
=
SamplingParams
(
params_with_logprobs
=
SamplingParams
(
logits_processors
=
[
pick_vllm
],
logits_processors
=
[
pick_vllm
],
prompt_logprobs
=
3
,
max_tokens
=
max_tokens
,
)
# test logits_processors when prompt_logprobs is not None
vllm_model
.
model
.
_add_request
(
example_prompts
[
0
],
params
=
params_with_logprobs
,
)
# test prompt_logprobs is not None
vllm_model
.
model
.
_add_request
(
example_prompts
[
1
],
params
=
SamplingParams
(
prompt_logprobs
=
3
,
prompt_logprobs
=
3
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
),
)
)
# test logits_processors when prompt_logprobs is not None
# test grouped requests
vllm_model
.
model
.
_add_request
(
vllm_model
.
model
.
_add_request
(
example_prompts
[
0
],
example_prompts
[
2
],
params
=
params_with_logprobs
,
params
=
SamplingParams
(
max_tokens
=
max_tokens
),
)
)
# test prompt_logprobs is not None
outputs
=
vllm_model
.
model
.
_run_engine
(
use_tqdm
=
False
)
vllm_model
.
model
.
_add_request
(
example_prompts
[
1
],
assert
outputs
[
0
].
outputs
[
0
].
text
==
enforced_answers
*
repeat_times
params
=
SamplingParams
(
prompt_logprobs
=
3
,
max_tokens
=
max_tokens
,
),
)
# test grouped requests
vllm_model
.
model
.
_add_request
(
example_prompts
[
2
],
params
=
SamplingParams
(
max_tokens
=
max_tokens
),
)
outputs
=
vllm_model
.
model
.
_run_engine
(
use_tqdm
=
False
)
assert
outputs
[
0
].
outputs
[
0
].
text
==
enforced_answers
*
repeat_times
tests/samplers/test_logprobs.py
View file @
f48954a4
...
@@ -12,6 +12,7 @@ MODELS = ["facebook/opt-125m"]
...
@@ -12,6 +12,7 @@ MODELS = ["facebook/opt-125m"]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
,
-
1
])
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
6
])
# 32000 == vocab_size
@
pytest
.
mark
.
parametrize
(
"num_top_logprobs"
,
[
6
])
# 32000 == vocab_size
@
pytest
.
mark
.
parametrize
(
"detokenize"
,
[
True
,
False
])
def
test_get_prompt_logprobs
(
def
test_get_prompt_logprobs
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
...
@@ -19,6 +20,7 @@ def test_get_prompt_logprobs(
...
@@ -19,6 +20,7 @@ def test_get_prompt_logprobs(
dtype
,
dtype
,
chunked_prefill_token_size
:
int
,
chunked_prefill_token_size
:
int
,
num_top_logprobs
:
int
,
num_top_logprobs
:
int
,
detokenize
:
bool
,
example_prompts
,
example_prompts
,
):
):
max_num_seqs
=
256
max_num_seqs
=
256
...
@@ -30,27 +32,27 @@ def test_get_prompt_logprobs(
...
@@ -30,27 +32,27 @@ def test_get_prompt_logprobs(
max_num_batched_tokens
=
chunked_prefill_token_size
max_num_batched_tokens
=
chunked_prefill_token_size
max_tokens
=
5
max_tokens
=
5
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
example_prompts
,
example_prompts
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
)
)
del
hf_model
with
vllm_runner
(
vllm_model
=
vllm_runner
(
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
max_logprobs
=
num_top_logprobs
,
max_logprobs
=
num_top_logprobs
,
enable_chunked_prefill
=
enable_chunked_prefill
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_
batched_tokens
=
max_num_batched_token
s
,
max_num_
seqs
=
max_num_seq
s
,
max_num_seqs
=
max_num_seqs
,
)
as
vllm_model
:
)
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_token
s
,
logprobs
=
num_top_logprob
s
,
logprobs
=
num_top_logprobs
,
prompt_
logprobs
=
num_top_logprobs
,
prompt_logprobs
=
num_top_logprobs
,
temperature
=
0.0
,
temperature
=
0.0
)
detokenize
=
detokenize
)
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
example_prompts
,
sampling_params
=
vllm_sampling_params
)
# Test whether logprobs are included in the results.
# Test whether logprobs are included in the results.
for
result
in
vllm_results
:
for
result
in
vllm_results
:
...
@@ -65,11 +67,16 @@ def test_get_prompt_logprobs(
...
@@ -65,11 +67,16 @@ def test_get_prompt_logprobs(
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
output_string_from_most_likely_tokens
.
append
(
output_string_from_most_likely_tokens
.
append
(
top_logprob
.
decoded_token
)
top_logprob
.
decoded_token
)
output_string_from_most_likely_tokens
=
""
.
join
(
output_string_from_most_likely_tokens
)
if
detokenize
:
assert
output_text
==
output_string_from_most_likely_tokens
,
(
output_string_from_most_likely_tokens
=
""
.
join
(
"The output text from the top logprob for each token position "
output_string_from_most_likely_tokens
)
"should be the same as the output text in the result."
)
assert
output_text
==
output_string_from_most_likely_tokens
,
(
"The output text from the top logprob for each token position "
"should be the same as the output text in the result."
)
else
:
assert
output_text
==
''
assert
output_string_from_most_likely_tokens
==
[
None
]
*
max_tokens
# The first prompt logprob is always None
# The first prompt logprob is always None
assert
result
.
prompt_logprobs
[
0
]
is
None
assert
result
.
prompt_logprobs
[
0
]
is
None
...
@@ -98,9 +105,10 @@ def test_get_prompt_logprobs(
...
@@ -98,9 +105,10 @@ def test_get_prompt_logprobs(
hf_logprob
[
i
][
-
1
][
token_id
].
item
(),
hf_logprob
[
i
][
-
1
][
token_id
].
item
(),
atol
=
1e-1
,
atol
=
1e-1
,
rtol
=
1e-1
)
rtol
=
1e-1
)
assert
isinstance
(
sample_logprob
.
decoded_token
,
str
),
(
if
detokenize
:
"The token should be decoded by the time it is returned "
assert
isinstance
(
sample_logprob
.
decoded_token
,
str
),
(
" to the user."
)
"The token should be decoded by the time it is returned"
" to the user."
)
# Test if prompt logprobs are correctly set.
# Test if prompt logprobs are correctly set.
for
vllm_result
in
vllm_results
:
for
vllm_result
in
vllm_results
:
...
...
tests/samplers/test_ranks.py
View file @
f48954a4
...
@@ -17,16 +17,27 @@ def test_ranks(
...
@@ -17,16 +17,27 @@ def test_ranks(
num_top_logprobs
=
5
num_top_logprobs
=
5
num_prompt_logprobs
=
5
num_prompt_logprobs
=
5
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
num_top_logprobs
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
num_top_logprobs
)
as
vllm_model
:
## Test greedy logprobs ranks
vllm_sampling_params
=
SamplingParams
(
temperature
=
0.0
,
## Test greedy logprobs ranks
top_p
=
1.0
,
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
temperature
=
0.0
,
logprobs
=
num_top_logprobs
,
top_p
=
1.0
,
prompt_logprobs
=
num_prompt_logprobs
)
max_tokens
=
max_tokens
,
vllm_results
=
vllm_model
.
generate_w_logprobs
(
example_prompts
,
logprobs
=
num_top_logprobs
,
vllm_sampling_params
)
prompt_logprobs
=
num_prompt_logprobs
)
vllm_results
=
vllm_model
.
generate_w_logprobs
(
example_prompts
,
vllm_sampling_params
)
## Test non-greedy logprobs ranks
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
top_p
=
1.0
,
max_tokens
=
max_tokens
,
logprobs
=
num_top_logprobs
,
prompt_logprobs
=
num_prompt_logprobs
)
res
=
vllm_model
.
generate_w_logprobs
(
example_prompts
,
sampling_params
)
for
result
in
vllm_results
:
for
result
in
vllm_results
:
assert
result
[
2
]
is
not
None
assert
result
[
2
]
is
not
None
assert
len
(
result
[
2
])
==
len
(
result
[
0
])
assert
len
(
result
[
2
])
==
len
(
result
[
0
])
...
@@ -35,13 +46,6 @@ def test_ranks(
...
@@ -35,13 +46,6 @@ def test_ranks(
assert
token
in
logprobs
assert
token
in
logprobs
assert
logprobs
[
token
].
rank
==
1
assert
logprobs
[
token
].
rank
==
1
## Test non-greedy logprobs ranks
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
top_p
=
1.0
,
max_tokens
=
max_tokens
,
logprobs
=
num_top_logprobs
,
prompt_logprobs
=
num_prompt_logprobs
)
res
=
vllm_model
.
generate_w_logprobs
(
example_prompts
,
sampling_params
)
for
result
in
res
:
for
result
in
res
:
assert
result
[
2
]
is
not
None
assert
result
[
2
]
is
not
None
assert
len
(
result
[
2
])
==
len
(
result
[
0
])
assert
len
(
result
[
2
])
==
len
(
result
[
0
])
...
...
tests/samplers/test_seeded_generate.py
View file @
f48954a4
...
@@ -17,9 +17,8 @@ RANDOM_SEEDS = list(range(5))
...
@@ -17,9 +17,8 @@ RANDOM_SEEDS = list(range(5))
@
pytest
.
fixture
@
pytest
.
fixture
def
vllm_model
(
vllm_runner
):
def
vllm_model
(
vllm_runner
):
vllm_model
=
vllm_runner
(
MODEL
,
dtype
=
"half"
)
with
vllm_runner
(
MODEL
,
dtype
=
"half"
)
as
vllm_model
:
yield
vllm_model
yield
vllm_model
del
vllm_model
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
...
...
tests/spec_decode/e2e/conftest.py
View file @
f48954a4
...
@@ -18,9 +18,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs
...
@@ -18,9 +18,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.multimodal
import
MultiModalData
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Logprob
,
MultiModalData
from
vllm.sequence
import
Logprob
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
random_uuid
from
vllm.utils
import
Counter
,
random_uuid
...
@@ -76,7 +77,11 @@ class AsyncLLM:
...
@@ -76,7 +77,11 @@ class AsyncLLM:
swap_space
=
swap_space
,
swap_space
=
swap_space
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
enforce_eager
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
# For now use ray for the distributed back-end, since
# we rely on the use of engine_use_ray=True to avoid
# reinitializing CUDA in the same process (driver worker)
engine_use_ray
=
True
,
engine_use_ray
=
True
,
distributed_executor_backend
=
"ray"
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
**
kwargs
,
**
kwargs
,
)
)
...
...
tests/spec_decode/test_dynamic_spec_decode.py
View file @
f48954a4
...
@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
...
@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
if
queue_size
<
disable_by_batch_size
:
if
queue_size
<
disable_by_batch_size
:
# Should raise exception when executing the mocked draft model.
# Should raise exception when executing the mocked draft model.
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposer
.
get_
spec_
proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
num_lookahead_slots
=
k
),
)
else
:
else
:
# Should not execute the draft model because spec decode is disabled
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
# for all requests. Accordingly, the proposal length should be 0.
proposals
=
proposer
.
get_proposals
(
proposals
=
proposer
.
get_
spec_
proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
num_lookahead_slots
=
k
),
)
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
f48954a4
...
@@ -307,9 +307,10 @@ def test_draft_proposals_full_speculation_len():
...
@@ -307,9 +307,10 @@ def test_draft_proposals_full_speculation_len():
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -344,9 +345,10 @@ def test_draft_proposals_no_speculations():
...
@@ -344,9 +345,10 @@ def test_draft_proposals_no_speculations():
k
,
k
,
prompt_len
=
prompt_len
)
prompt_len
=
prompt_len
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -415,9 +417,10 @@ def test_draft_proposals_mixed_k():
...
@@ -415,9 +417,10 @@ def test_draft_proposals_mixed_k():
prev_output_token_len
=
prev_output_token_len
,
prev_output_token_len
=
prev_output_token_len
,
)
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_ngram_worker.py
View file @
f48954a4
...
@@ -50,9 +50,10 @@ def test_ngram_algo_correctness_for_single_no_match():
...
@@ -50,9 +50,10 @@ def test_ngram_algo_correctness_for_single_no_match():
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
proposal_len
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -117,9 +118,10 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
...
@@ -117,9 +118,10 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
proposal_len
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -188,9 +190,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
...
@@ -188,9 +190,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
proposal_len
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/tensorizer_loader/test_tensorizer.py
View file @
f48954a4
import
gc
import
json
import
json
import
os
import
os
import
subprocess
import
subprocess
...
@@ -7,7 +6,6 @@ from unittest.mock import MagicMock, patch
...
@@ -7,7 +6,6 @@ from unittest.mock import MagicMock, patch
import
openai
import
openai
import
pytest
import
pytest
import
ray
import
ray
import
torch
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
# yapf: disable
# yapf: disable
...
@@ -71,72 +69,66 @@ def test_can_deserialize_s3(vllm_runner):
...
@@ -71,72 +69,66 @@ def test_can_deserialize_s3(vllm_runner):
model_ref
=
"EleutherAI/pythia-1.4b"
model_ref
=
"EleutherAI/pythia-1.4b"
tensorized_path
=
f
"s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
tensorized_path
=
f
"s3://tensorized/
{
model_ref
}
/fp16/model.tensors"
loaded_hf_model
=
vllm_runner
(
model_ref
,
with
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
tensorized_path
,
tensorizer_uri
=
tensorized_path
,
num_readers
=
1
,
num_readers
=
1
,
s3_endpoint
=
"object.ord1.coreweave.com"
,
s3_endpoint
=
"object.ord1.coreweave.com"
,
))
))
as
loaded_hf_model
:
deserialized_outputs
=
loaded_hf_model
.
generate
(
prompts
,
sampling_params
)
deserialized_outputs
=
loaded_hf_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
assert
deserialized_outputs
assert
deserialized_outputs
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
def
test_deserialized_encrypted_vllm_model_has_same_outputs
(
def
test_deserialized_encrypted_vllm_model_has_same_outputs
(
vllm_runner
,
tmp_path
):
vllm_runner
,
tmp_path
):
vllm_model
=
vllm_runner
(
model_ref
)
with
vllm_runner
(
model_ref
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
key_path
=
tmp_path
/
(
model_ref
+
".key"
)
key_path
=
tmp_path
/
(
model_ref
+
".key"
)
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
config_for_serializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
)
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
config_for_serializing
,
encryption_key_path
=
key_path
)
del
vllm_model
config_for_serializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
)
gc
.
collect
()
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
torch
.
cuda
.
empty_cache
()
config_for_serializing
,
encryption_key_path
=
key_path
)
config_for_deserializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
config_for_deserializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
encryption_keyfile
=
key_path
)
encryption_keyfile
=
key_path
)
loaded_vllm_model
=
vllm_runner
(
with
vllm_runner
(
model_ref
,
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
config_for_deserializing
)
model_loader_extra_config
=
config_for_deserializing
)
as
loaded_vllm_model
:
# noqa: E501
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
assert
outputs
==
deserialized_outputs
assert
outputs
==
deserialized_outputs
def
test_deserialized_hf_model_has_same_outputs
(
hf_runner
,
vllm_runner
,
def
test_deserialized_hf_model_has_same_outputs
(
hf_runner
,
vllm_runner
,
tmp_path
):
tmp_path
):
hf_model
=
hf_runner
(
model_ref
)
with
hf_runner
(
model_ref
)
as
hf_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
max_tokens
=
50
max_tokens
=
50
outputs
=
hf_model
.
generate_greedy
(
prompts
,
max_tokens
=
max_tokens
)
outputs
=
hf_model
.
generate_greedy
(
prompts
,
max_tokens
=
max_tokens
)
with
open_stream
(
model_path
,
"wb+"
)
as
stream
:
with
open_stream
(
model_path
,
"wb+"
)
as
stream
:
serializer
=
TensorSerializer
(
stream
)
serializer
=
TensorSerializer
(
stream
)
serializer
.
write_module
(
hf_model
.
model
)
serializer
.
write_module
(
hf_model
.
model
)
del
hf_model
gc
.
collect
()
with
vllm_runner
(
model_ref
,
torch
.
cuda
.
empty_cache
()
loaded_hf_model
=
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
model_loader_extra_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
tensorizer_uri
=
model_path
,
num_readers
=
1
,
num_readers
=
1
,
))
))
as
loaded_hf_model
:
deserialized_outputs
=
loaded_hf_model
.
generate_greedy
(
deserialized_outputs
=
loaded_hf_model
.
generate_greedy
(
prompts
,
max_tokens
=
max_tokens
)
prompts
,
max_tokens
=
max_tokens
)
assert
outputs
==
deserialized_outputs
assert
outputs
==
deserialized_outputs
def
test_vllm_model_can_load_with_lora
(
vllm_runner
,
tmp_path
):
def
test_vllm_model_can_load_with_lora
(
vllm_runner
,
tmp_path
):
...
@@ -150,16 +142,13 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
...
@@ -150,16 +142,13 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
test_prompts
=
create_test_prompts
(
lora_path
)
test_prompts
=
create_test_prompts
(
lora_path
)
# Serialize model before deserializing and binding LoRA adapters
# Serialize model before deserializing and binding LoRA adapters
vllm_model
=
vllm_runner
(
model_ref
,
)
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
TensorizerConfig
(
tensorizer_uri
=
model_path
))
TensorizerConfig
(
tensorizer_uri
=
model_path
))
del
vllm_model
with
vllm_runner
(
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
loaded_vllm_model
=
vllm_runner
(
model_ref
,
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
TensorizerConfig
(
model_loader_extra_config
=
TensorizerConfig
(
...
@@ -172,10 +161,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
...
@@ -172,10 +161,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
max_cpu_loras
=
2
,
max_cpu_loras
=
2
,
max_num_seqs
=
50
,
max_num_seqs
=
50
,
max_model_len
=
1000
,
max_model_len
=
1000
,
)
)
as
loaded_vllm_model
:
process_requests
(
loaded_vllm_model
.
model
.
llm_engine
,
test_prompts
)
process_requests
(
loaded_vllm_model
.
model
.
llm_engine
,
test_prompts
)
assert
loaded_vllm_model
assert
loaded_vllm_model
def
test_load_without_tensorizer_load_format
(
vllm_runner
):
def
test_load_without_tensorizer_load_format
(
vllm_runner
):
...
@@ -188,19 +177,15 @@ def test_load_without_tensorizer_load_format(vllm_runner):
...
@@ -188,19 +177,15 @@ def test_load_without_tensorizer_load_format(vllm_runner):
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
@
pytest
.
mark
.
skipif
(
not
is_curl_installed
(),
reason
=
"cURL is not installed"
)
def
test_openai_apiserver_with_tensorizer
(
vllm_runner
,
tmp_path
):
def
test_openai_apiserver_with_tensorizer
(
vllm_runner
,
tmp_path
):
## Serialize model
## Serialize model
vllm_model
=
vllm_runner
(
model_ref
,
)
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
TensorizerConfig
(
tensorizer_uri
=
model_path
))
model_loader_extra_config
=
{
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
"tensorizer_uri"
:
str
(
model_path
),
TensorizerConfig
(
tensorizer_uri
=
model_path
))
}
del
vllm_model
model_loader_extra_config
=
{
gc
.
collect
()
"tensorizer_uri"
:
str
(
model_path
),
torch
.
cuda
.
empty_cache
()
}
## Start OpenAI API server
## Start OpenAI API server
openai_args
=
[
openai_args
=
[
...
@@ -224,9 +209,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
...
@@ -224,9 +209,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
temperature
=
0.0
)
temperature
=
0.0
)
assert
completion
.
id
is
not
None
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
1
assert
len
(
completion
.
choices
)
==
1
assert
completion
.
choices
[
0
].
text
is
not
None
and
len
(
assert
len
(
completion
.
choices
[
0
].
text
)
>=
5
completion
.
choices
[
0
].
text
)
>=
5
assert
completion
.
choices
[
0
].
finish_reason
==
"length"
assert
completion
.
choices
[
0
].
finish_reason
==
"length"
assert
completion
.
usage
==
openai
.
types
.
CompletionUsage
(
assert
completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
5
,
prompt_tokens
=
6
,
total_tokens
=
11
)
completion_tokens
=
5
,
prompt_tokens
=
6
,
total_tokens
=
11
)
...
@@ -262,18 +246,15 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
...
@@ -262,18 +246,15 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
config
=
TensorizerConfig
(
tensorizer_uri
=
str
(
model_path
))
config
=
TensorizerConfig
(
tensorizer_uri
=
str
(
model_path
))
vllm_model
=
vllm_runner
(
model_ref
)
with
vllm_runner
(
model_ref
)
as
vllm_model
:
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
config
)
serialize_vllm_model
(
vllm_model
.
model
.
llm_engine
,
config
)
assert
is_vllm_tensorized
(
config
)
assert
is_vllm_tensorized
(
config
)
del
vllm_model
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
loaded_vllm_model
=
vllm_runner
(
model_ref
,
with
vllm_runner
(
model_ref
,
load_format
=
"tensorizer"
,
load_format
=
"tensorizer"
,
model_loader_extra_config
=
config
)
model_loader_extra_config
=
config
)
as
loaded_vllm_model
:
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
deserialized_outputs
=
loaded_vllm_model
.
generate
(
prompts
,
sampling_params
)
# noqa: E501
assert
outputs
==
deserialized_outputs
assert
outputs
==
deserialized_outputs
tests/test_config.py
View file @
f48954a4
...
@@ -63,8 +63,9 @@ def test_get_sliding_window():
...
@@ -63,8 +63,9 @@ def test_get_sliding_window():
assert
mistral_model_config
.
get_sliding_window
()
==
TEST_SLIDING_WINDOW
assert
mistral_model_config
.
get_sliding_window
()
==
TEST_SLIDING_WINDOW
def
test_rope_
scaling
():
def
test_rope_
customization
():
TEST_ROPE_SCALING
=
{
"type"
:
"dynamic"
,
"factor"
:
2.0
}
TEST_ROPE_SCALING
=
{
"type"
:
"dynamic"
,
"factor"
:
2.0
}
TEST_ROPE_THETA
=
16_000_000.0
LONGCHAT_ROPE_SCALING
=
{
"type"
:
"linear"
,
"factor"
:
8.0
}
LONGCHAT_ROPE_SCALING
=
{
"type"
:
"linear"
,
"factor"
:
8.0
}
llama_model_config
=
ModelConfig
(
llama_model_config
=
ModelConfig
(
...
@@ -76,6 +77,7 @@ def test_rope_scaling():
...
@@ -76,6 +77,7 @@ def test_rope_scaling():
seed
=
0
,
seed
=
0
,
)
)
assert
getattr
(
llama_model_config
.
hf_config
,
"rope_scaling"
,
None
)
is
None
assert
getattr
(
llama_model_config
.
hf_config
,
"rope_scaling"
,
None
)
is
None
assert
getattr
(
llama_model_config
.
hf_config
,
"rope_theta"
,
None
)
==
500_000
assert
llama_model_config
.
max_model_len
==
8192
assert
llama_model_config
.
max_model_len
==
8192
llama_model_config
=
ModelConfig
(
llama_model_config
=
ModelConfig
(
...
@@ -86,9 +88,12 @@ def test_rope_scaling():
...
@@ -86,9 +88,12 @@ def test_rope_scaling():
dtype
=
"float16"
,
dtype
=
"float16"
,
seed
=
0
,
seed
=
0
,
rope_scaling
=
TEST_ROPE_SCALING
,
rope_scaling
=
TEST_ROPE_SCALING
,
rope_theta
=
TEST_ROPE_THETA
,
)
)
assert
getattr
(
llama_model_config
.
hf_config
,
"rope_scaling"
,
assert
getattr
(
llama_model_config
.
hf_config
,
"rope_scaling"
,
None
)
==
TEST_ROPE_SCALING
None
)
==
TEST_ROPE_SCALING
assert
getattr
(
llama_model_config
.
hf_config
,
"rope_theta"
,
None
)
==
TEST_ROPE_THETA
assert
llama_model_config
.
max_model_len
==
16384
assert
llama_model_config
.
max_model_len
==
16384
longchat_model_config
=
ModelConfig
(
longchat_model_config
=
ModelConfig
(
...
...
tests/test_regression.py
View file @
f48954a4
...
@@ -53,6 +53,27 @@ def test_gc():
...
@@ -53,6 +53,27 @@ def test_gc():
assert
allocated
<
50
*
1024
*
1024
assert
allocated
<
50
*
1024
*
1024
def test_model_from_modelscope(monkeypatch):
    """Generate with a model resolved while ``VLLM_USE_MODELSCOPE`` is set.

    The environment variable is set through ``monkeypatch`` before the
    engine is built and removed again in a ``finally`` clause, whether or
    not generation succeeds. The test passes when one output comes back for
    each of the four prompts.
    """
    # model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
    MODELSCOPE_MODEL_NAME = "qwen/Qwen1.5-0.5B-Chat"
    monkeypatch.setenv("VLLM_USE_MODELSCOPE", "True")
    try:
        engine = LLM(model=MODELSCOPE_MODEL_NAME)

        test_prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = SamplingParams(temperature=0.8, top_p=0.95)

        results = engine.generate(test_prompts, params)
        assert len(results) == 4
    finally:
        monkeypatch.delenv("VLLM_USE_MODELSCOPE", raising=False)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
import
pytest
import
pytest
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
tests/test_sharded_state_loader.py
View file @
f48954a4
import
multiprocessing
as
mp
import
os
import
os
import
shutil
import
shutil
from
tempfile
import
TemporaryDirectory
from
tempfile
import
TemporaryDirectory
...
@@ -18,9 +19,7 @@ prompts = [
...
@@ -18,9 +19,7 @@ prompts = [
# Create a sampling params object.
# Create a sampling params object.
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
temperature
=
0
,
top_p
=
0.95
,
seed
=
0
,
max_tokens
=
256
,
max_tokens
=
256
,
ignore_eos
=
True
,
ignore_eos
=
True
,
)
)
...
@@ -40,51 +39,89 @@ def test_filter_subtensors():
...
@@ -40,51 +39,89 @@ def test_filter_subtensors():
filtered_state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
state_dict
)
filtered_state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
state_dict
)
assert
tuple
(
filtered_state_dict
.
keys
())
==
(
"a"
,
"b"
,
"c"
)
assert
tuple
(
filtered_state_dict
.
keys
())
==
(
"a"
,
"b"
,
"c"
)
for
key
,
tensor
in
filtered_state_dict
.
items
():
for
key
,
tensor
in
filtered_state_dict
.
items
():
assert
tensor
.
equal
(
state_dict
[
key
])
# NOTE: don't use `equal` here, as the tensor might contain NaNs
assert
tensor
is
state_dict
[
key
]
@pytest.fixture(scope="module")
def llama_2_7b_files():
    """Yield the local snapshot path of ``meta-llama/Llama-2-7b-hf``.

    The snapshot is downloaded once per test module into a temporary cache
    directory, skipping files that match ``*.bin*``. The temporary directory
    is removed when the fixture is torn down.
    """
    with TemporaryDirectory() as tmp_cache:
        snapshot_path = snapshot_download(
            "meta-llama/Llama-2-7b-hf",
            cache_dir=tmp_cache,
            ignore_patterns="*.bin*",
        )
        yield snapshot_path
def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
    """Save sharded worker state for a model and copy its metadata files.

    Args:
        input_dir: Directory holding the source model; passed to ``LLM`` as
            ``model`` and scanned for files to copy.
        output_dir: Destination for the sharded state and the copied files.
        weights_patterns: Iterable of weight-file patterns to exclude from
            the copy. Entries may be glob patterns (``"*.safetensors"``) or
            plain suffixes (``".safetensors"``).
        **kwargs: Forwarded unchanged to the ``LLM`` constructor.
    """
    # Imported locally so this helper does not rely on a file-level import.
    import fnmatch

    llm_sharded_writer = LLM(model=input_dir, **kwargs)

    # Dump worker states to output directory
    llm_sharded_writer.llm_engine.model_executor.save_sharded_state(
        path=output_dir)

    def _is_weight_file(name):
        # The patterns are globs, so a bare ``str.endswith`` never matches
        # them (no file name contains a literal "*"). Match as a glob, and
        # keep the suffix check so plain-suffix callers still work.
        return any(
            fnmatch.fnmatch(name, pattern) or name.endswith(pattern)
            for pattern in weights_patterns)

    # Copy metadata files to output directory
    for file in os.listdir(input_dir):
        if not _is_weight_file(file):
            shutil.copy(os.path.join(input_dir, file), output_dir)
def _run_generate(input_dir, queue: mp.Queue, **kwargs):
    """Generate for the module-level prompts and send the results to ``queue``.

    Builds an ``LLM`` from ``input_dir`` (extra ``kwargs`` are forwarded),
    runs ``generate`` with the module-level ``prompts`` and
    ``sampling_params``, and puts a list containing the ``__dict__`` of the
    first output of every request on ``queue``. The queue is then closed and
    its feeder thread joined before the function returns.
    """
    engine = LLM(model=input_dir, **kwargs)
    request_outputs = engine.generate(prompts, sampling_params)
    first_completions = [
        request.outputs[0].__dict__ for request in request_outputs
    ]
    queue.put(first_completions)
    queue.close()
    queue.join_thread()
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"enable_lora"
,
[
False
,
True
])
def
test_sharded_state_loader
(
enable_lora
):
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
weights_patterns
=
(
"*.bin"
,
"*.pt"
,
"*.safetensors"
)
def
test_sharded_state_loader
(
enable_lora
,
tp_size
,
num_gpus_available
,
llama_2_7b_files
):
if
num_gpus_available
<
tp_size
:
pytest
.
skip
(
f
"Not enough GPUs for tensor parallelism
{
tp_size
}
"
)
with
TemporaryDirectory
()
as
cache_dir
,
TemporaryDirectory
()
as
output_dir
:
weights_patterns
=
(
"*.safetensors"
,
)
input_dir
=
snapshot_download
(
"meta-llama/Llama-2-7b-hf"
,
gpu_memory_utilization
=
0.8
cache_dir
=
cache_dir
)
input_dir
=
llama_2_7b_files
ctx
=
mp
.
get_context
(
"spawn"
)
llm
=
LLM
(
model
=
input_dir
,
# Run in separate processes for memory & CUDA isolation
worker_use_ray
=
True
,
with
TemporaryDirectory
()
as
output_dir
:
gpu_memory_utilization
=
0.3
,
p
=
ctx
.
Process
(
target
=
_run_writer
,
)
args
=
(
input_dir
,
output_dir
,
weights_patterns
),
kwargs
=
dict
(
# Dump worker states to output directory
tensor_parallel_size
=
tp_size
,
model_executor
=
llm
.
llm_engine
.
model_executor
distributed_executor_backend
=
"mp"
,
model_executor
.
save_sharded_state
(
path
=
output_dir
)
gpu_memory_utilization
=
gpu_memory_utilization
,
# Copy metadata files to output directory
enforce_eager
=
True
,
for
file
in
os
.
listdir
(
input_dir
):
))
if
not
any
(
file
.
endswith
(
ext
)
for
ext
in
weights_patterns
):
p
.
start
()
shutil
.
copy
(
f
"
{
input_dir
}
/
{
file
}
"
,
output_dir
)
p
.
join
()
del
llm
.
llm_engine
.
model_executor
queue
=
ctx
.
Queue
()
llm_before
=
LLM
(
model
=
input_dir
,
p
=
ctx
.
Process
(
target
=
_run_generate
,
worker_use_ray
=
True
,
args
=
(
input_dir
,
queue
),
enable_lora
=
enable_lora
,
kwargs
=
dict
(
gpu_memory_utilization
=
0.3
,
distributed_executor_backend
=
"mp"
,
)
enable_lora
=
enable_lora
,
gen_before
=
llm_before
.
generate
(
prompts
,
sampling_params
)
gpu_memory_utilization
=
gpu_memory_utilization
,
out_before
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_before
]
tensor_parallel_size
=
tp_size
,
del
llm_before
.
llm_engine
.
model_executor
))
p
.
start
()
llm_after
=
LLM
(
p
.
join
()
model
=
output_dir
,
out_before
=
queue
.
get
()
worker_use_ray
=
True
,
enable_lora
=
enable_lora
,
p
=
ctx
.
Process
(
target
=
_run_generate
,
gpu_memory_utilization
=
0.3
,
args
=
(
output_dir
,
queue
),
load_format
=
"sharded_state"
,
kwargs
=
dict
(
)
distributed_executor_backend
=
"mp"
,
gen_after
=
llm_after
.
generate
(
prompts
,
sampling_params
)
enable_lora
=
enable_lora
,
out_after
=
[
gen
.
outputs
[
0
].
__dict__
for
gen
in
gen_after
]
gpu_memory_utilization
=
gpu_memory_utilization
,
del
llm_after
.
llm_engine
.
model_executor
tensor_parallel_size
=
tp_size
,
load_format
=
"sharded_state"
,
))
p
.
start
()
p
.
join
()
out_after
=
queue
.
get
()
assert
out_before
==
out_after
assert
out_before
==
out_after
tests/test_utils.py
View file @
f48954a4
import
asyncio
import
asyncio
import
os
import
socket
import
sys
import
sys
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncIterator
,
Awaitable
,
Protocol
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncIterator
,
Awaitable
,
Protocol
,
Tuple
,
TypeVar
)
Tuple
,
TypeVar
)
import
pytest
import
pytest
from
vllm.utils
import
deprecate_kwargs
,
merge_async_iterators
from
vllm.utils
import
deprecate_kwargs
,
get_open_port
,
merge_async_iterators
from
.utils
import
error_on_warning
from
.utils
import
error_on_warning
...
@@ -116,3 +118,15 @@ def test_deprecate_kwargs_additional_message():
...
@@ -116,3 +118,15 @@ def test_deprecate_kwargs_additional_message():
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"abcd"
):
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"abcd"
):
dummy
(
old_arg
=
1
)
dummy
(
old_arg
=
1
)
def test_get_open_port():
    """Check that ``get_open_port`` keeps returning bindable ports.

    ``VLLM_PORT`` is set to a fixed value for the duration of the test. Three
    sockets are then bound one after another, each to the port returned by a
    fresh ``get_open_port()`` call, while the earlier sockets stay open.
    """
    previous_port = os.environ.get("VLLM_PORT")
    os.environ["VLLM_PORT"] = "5678"
    try:
        # make sure we can get multiple ports, even if the env var is set
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
            s1.bind(("localhost", get_open_port()))
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
                s2.bind(("localhost", get_open_port()))
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
                    s3.bind(("localhost", get_open_port()))
    finally:
        # Undo the environment change even when a bind fails, so the override
        # cannot leak into later tests; put back any value set beforehand.
        if previous_port is None:
            os.environ.pop("VLLM_PORT", None)
        else:
            os.environ["VLLM_PORT"] = previous_port
tests/tokenization/test_image_processor.py
0 → 100644
View file @
f48954a4
import
pytest
from
transformers.image_processing_utils
import
BaseImageProcessor
from
vllm.transformers_utils.image_processor
import
get_image_processor
IMAGE_PROCESSOR_NAMES
=
[
"llava-hf/llava-1.5-7b-hf"
,
"llava-hf/llava-v1.6-34b-hf"
,
]
@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES)
def test_image_processor_revision(processor_name: str):
    """Check ``get_image_processor`` with an existing and a missing revision.

    Loading at revision ``"main"`` must return a ``BaseImageProcessor``;
    loading at revision ``"never"`` must raise ``OSError`` whose message
    matches ``not a valid git identifier``.
    """
    # Assume that "main" branch always exists
    loaded = get_image_processor(processor_name, revision="main")
    assert isinstance(loaded, BaseImageProcessor)

    # Assume that "never" branch always does not exist
    expected_failure = pytest.raises(OSError,
                                     match='not a valid git identifier')
    with expected_failure:
        get_image_processor(processor_name, revision="never")
tests/utils.py
View file @
f48954a4
...
@@ -24,7 +24,8 @@ class ServerRunner:
...
@@ -24,7 +24,8 @@ class ServerRunner:
env
=
os
.
environ
.
copy
()
env
=
os
.
environ
.
copy
()
env
[
"PYTHONUNBUFFERED"
]
=
"1"
env
[
"PYTHONUNBUFFERED"
]
=
"1"
self
.
proc
=
subprocess
.
Popen
(
self
.
proc
=
subprocess
.
Popen
(
[
"python3"
,
"-m"
,
"vllm.entrypoints.openai.api_server"
]
+
args
,
[
sys
.
executable
,
"-m"
,
"vllm.entrypoints.openai.api_server"
]
+
args
,
env
=
env
,
env
=
env
,
stdout
=
sys
.
stdout
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stderr
,
stderr
=
sys
.
stderr
,
...
...
vllm/__init__.py
View file @
f48954a4
...
@@ -13,7 +13,7 @@ from vllm.pooling_params import PoolingParams
...
@@ -13,7 +13,7 @@ from vllm.pooling_params import PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.version
import
__dcu_version__
from
vllm.version
import
__dcu_version__
__version__
=
"0.
4.3
"
__version__
=
"0.
5.0
"
__all__
=
[
__all__
=
[
"LLM"
,
"LLM"
,
...
...
vllm/_custom_ops.py
View file @
f48954a4
from
typing
import
Optional
,
Tuple
,
Type
import
contextlib
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
try
:
try
:
from
vllm._C
import
cache_ops
as
vllm_cache_ops
import
vllm._C
from
vllm._C
import
ops
as
vllm_ops
except
ImportError
as
e
:
except
ImportError
:
from
vllm.logger
import
init_logger
pass
logger
=
init_logger
(
__name__
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
with
contextlib
.
suppress
(
ImportError
):
import
vllm._moe_C
with
contextlib
.
suppress
(
ImportError
):
# ruff: noqa: F401
import
vllm._punica_C
def is_custom_op_supported(op_name: str) -> bool:
    """Report whether torch can resolve the operator named ``op_name``.

    Args:
        op_name: Operator name as accepted by
            ``torch._C._jit_get_operation``.

    Returns:
        ``True`` when the lookup yields a non-``None`` operator, ``False``
        otherwise.
    """
    # The lookup returns an (operator, overloads) pair; only the operator
    # matters here, so the overload list is discarded.
    op, _ = torch._C._jit_get_operation(op_name)
    return op is not None
# activation ops
# activation ops
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
silu_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
silu_and_mul
(
out
,
x
)
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
gelu_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_and_mul
(
out
,
x
)
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
gelu_tanh_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
gelu_fast
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_fast
(
out
,
x
)
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_new
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
gelu_new
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_new
(
out
,
x
)
# page attention ops
# page attention ops
...
@@ -51,7 +65,7 @@ def paged_attention_v1(
...
@@ -51,7 +65,7 @@ def paged_attention_v1(
blocksparse_block_size
:
int
=
64
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
)
->
None
:
vllm_ops
.
paged_attention_v1
(
torch
.
ops
.
_C
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
...
@@ -81,7 +95,7 @@ def paged_attention_v2(
...
@@ -81,7 +95,7 @@ def paged_attention_v2(
blocksparse_block_size
:
int
=
64
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
)
->
None
:
vllm_ops
.
paged_attention_v2
(
torch
.
ops
.
_C
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
...
@@ -98,8 +112,8 @@ def rotary_embedding(
...
@@ -98,8 +112,8 @@ def rotary_embedding(
cos_sin_cache
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
is_neox
:
bool
,
)
->
None
:
)
->
None
:
vllm_ops
.
rotary_embedding
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
torch
.
ops
.
_C
.
rotary_embedding
(
positions
,
query
,
key
,
head_size
,
is_neox
)
cos_sin_cache
,
is_neox
)
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -107,20 +121,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
...
@@ -107,20 +121,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
rot_dim
:
int
,
rot_dim
:
int
,
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
batched_rotary_embedding
(
positions
,
query
,
key
,
head_size
,
torch
.
ops
.
_C
.
batched_rotary_embedding
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache_offsets
)
cos_sin_cache_offsets
)
# layer norm ops
# layer norm ops
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
epsilon
:
float
)
->
None
:
vllm_ops
.
rms_norm
(
out
,
input
,
weight
,
epsilon
)
torch
.
ops
.
_C
.
rms_norm
(
out
,
input
,
weight
,
epsilon
)
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
vllm_ops
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
# quantization ops
# quantization ops
...
@@ -128,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
...
@@ -128,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
def
awq_dequantize
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
def
awq_dequantize
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
int
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
int
,
thy
:
int
)
->
torch
.
Tensor
:
thy
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
awq_dequantize
(
qweight
,
scales
,
zeros
,
split_k_iters
,
thx
,
return
torch
.
ops
.
_C
.
awq_dequantize
(
qweight
,
scales
,
zeros
,
split_k_iters
,
thy
)
thx
,
thy
)
def
awq_gemm
(
input
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
def
awq_gemm
(
input
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
split_k_iters
:
int
)
->
torch
.
Tensor
:
scales
:
torch
.
Tensor
,
split_k_iters
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
awq_gemm
(
input
,
qweight
,
qzeros
,
scales
,
split_k_iters
)
return
torch
.
ops
.
_C
.
awq_gemm
(
input
,
qweight
,
qzeros
,
scales
,
split_k_iters
)
# gptq
# gptq
...
@@ -142,27 +156,27 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -142,27 +156,27 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
bit
:
int
)
->
torch
.
Tensor
:
bit
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
return
torch
.
ops
.
_C
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
use_exllama
,
bit
)
b_g_idx
,
use_exllama
,
bit
)
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
bit
:
int
)
->
None
:
bit
:
int
)
->
None
:
vllm_ops
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# squeezellm
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
lookup_table
:
torch
.
Tensor
)
->
None
:
lookup_table
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
squeezellm_gemm
(
vec
,
mat
,
mul
,
lookup_table
)
torch
.
ops
.
_C
.
squeezellm_gemm
(
vec
,
mat
,
mul
,
lookup_table
)
# marlin
# marlin
def
marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
workspace
,
size_m
,
return
torch
.
ops
.
_C
.
marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
workspace
,
size_m
,
size_n
,
size_k
)
size_n
,
size_k
)
# marlin_24
# marlin_24
...
@@ -170,14 +184,14 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -170,14 +184,14 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_meta
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_24_gemm
(
a
,
b_q_weight
,
b_meta
,
b_scales
,
return
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
(
a
,
b_q_weight
,
b_meta
,
b_scales
,
workspace
,
num_bits
,
size_m
,
size_n
,
workspace
,
num_bits
,
size_m
,
size_k
)
size_n
,
size_k
)
# cutlass
# cutlass
def
cutlass_scaled_mm_dq
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
def
cutlass_scaled_mm_dq
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
a_
scale
s
:
torch
.
Tensor
,
b_
scale
s
:
torch
.
Tensor
,
scale
_a
:
torch
.
Tensor
,
scale
_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
])
->
torch
.
Tensor
:
out_dtype
:
Type
[
torch
.
dtype
])
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
...
@@ -186,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
...
@@ -186,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n
=
b
.
shape
[
1
]
n
=
b
.
shape
[
1
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
vllm_ops
.
cutlass_scaled_mm_dq
(
out
,
a
,
b
,
a_
scale
s
,
b_
scale
s
)
torch
.
ops
.
_C
.
cutlass_scaled_mm_dq
(
out
,
a
,
b
,
scale
_a
,
scale
_b
)
return
out
return
out
...
@@ -196,21 +210,22 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
...
@@ -196,21 +210,22 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
vllm_ops
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
return
torch
.
ops
.
_C
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
codebook_partition_sizes
,
bias
)
codebook_partition_sizes
,
bias
)
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
)
->
torch
.
Tensor
:
codebook_partition_sizes
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
vllm_ops
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
return
torch
.
ops
.
_C
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
# gptq_marlin
# gptq_marlin
def
gptq_marlin_repack
(
b_q_weight
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
def
gptq_marlin_repack
(
b_q_weight
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
num_bits
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
,
return
torch
.
ops
.
_C
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
,
num_bits
)
num_bits
)
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
...
@@ -218,9 +233,9 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -218,9 +233,9 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
)
->
torch
.
Tensor
:
is_k_full
:
bool
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
workspace
,
num_bits
,
size_m
,
size_n
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
,
is_k_full
)
size_k
,
is_k_full
)
# fp8
# fp8
...
@@ -257,28 +272,40 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -257,28 +272,40 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# if scale is None:
# if scale is None:
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
#
vllm_ops
.dynamic_scaled_fp8_quant(output, input, scale)
#
torch.ops._C
.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# else:
#
vllm_ops
.static_scaled_fp8_quant(output, input, scale)
#
torch.ops._C
.static_scaled_fp8_quant(output, input, scale)
# return output, scale
# return output, scale
# int8
# int8
def
static_scaled_int8_quant
(
input
:
torch
.
Tensor
,
def
scaled_int8_quant
(
scale
:
float
)
->
torch
.
Tensor
:
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Quantize the input tensor to int8 and return the quantized tensor.
Quantize the input tensor to int8 and return the quantized tensor
and scale
.
Args:
Args:
input: The input tensor to be quantized to int8.
input: The input tensor to be quantized to int8.
scale: Scaling factor for the int8 quantization.
scale: Optional scaling factor for the int8 quantization.
When not provided, we invoke dynamic-per-token quantization.
Returns:
Returns:
t
orch.Tensor: Output tensor
in int8
.
Tuple[Torch.Tensor, T
orch.Tensor
]
: Output
int8
tensor
and scales
.
"""
"""
q
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
int8
)
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
int8
)
vllm_ops
.
static_scaled_int8_quant
(
q
,
input
,
scale
)
if
scale
is
not
None
:
return
q
# static-per-tensor quantization.
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
output
,
input
,
scale
)
return
output
,
scale
# dynamic-per-token quantization.
input_scales
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
output
,
input
,
input_scales
)
return
output
,
input_scales
# moe
# moe
...
@@ -286,9 +313,16 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
...
@@ -286,9 +313,16 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
experts_ids
:
torch
.
Tensor
,
experts_ids
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
)
->
None
:
num_tokens_post_pad
:
torch
.
Tensor
)
->
None
:
vllm_ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
torch
.
ops
.
_C
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_token_ids
,
experts_ids
,
sorted_token_ids
,
experts_ids
,
num_tokens_post_pad
)
num_tokens_post_pad
)
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indicies: torch.Tensor,
    gating_output: float,
) -> None:
    """Invoke the ``topk_softmax`` op registered under ``torch.ops._moe_C``.

    All four arguments are passed through positionally in the order given,
    and nothing is returned.

    NOTE(review): ``gating_output`` is annotated as ``float`` while the other
    arguments are tensors -- confirm against the op's schema.
    """
    moe_topk_softmax = torch.ops._moe_C.topk_softmax
    moe_topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output,
    )
def
reshape_and_cache
(
def
reshape_and_cache
(
...
@@ -300,8 +334,9 @@ def reshape_and_cache(
...
@@ -300,8 +334,9 @@ def reshape_and_cache(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
kv_scale
:
float
,
)
->
None
:
)
->
None
:
vllm_cache_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
value_cache
,
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
def
reshape_and_cache_flash
(
def
reshape_and_cache_flash
(
...
@@ -312,25 +347,115 @@ def reshape_and_cache_flash(
...
@@ -312,25 +347,115 @@ def reshape_and_cache_flash(
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
vllm_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
slot_mapping
,
kv_cache_dtype
)
value_cache
,
slot_mapping
,
kv_cache_dtype
)
def
copy_blocks
(
key_caches
:
torch
.
Tensor
,
value_caches
:
torch
.
Tensor
,
def
copy_blocks
(
key_caches
:
torch
.
Tensor
,
value_caches
:
torch
.
Tensor
,
block_mapping
:
torch
.
Tensor
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
vllm
_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
torch
.
ops
.
_C
_cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping
)
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_mapping
:
torch
.
Tensor
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
vllm
_cache_ops
.
swap_blocks
(
src
,
dst
,
block_mapping
)
torch
.
ops
.
_C
_cache_ops
.
swap_blocks
(
src
,
dst
,
block_mapping
)
def
convert_fp8
(
output
:
torch
.
Tensor
,
def
convert_fp8
(
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
float
=
1.0
,
scale
:
float
=
1.0
,
kv_dtype
:
str
=
"fp8"
)
->
None
:
kv_dtype
:
str
=
"fp8"
)
->
None
:
vllm_cache_ops
.
convert_fp8
(
output
,
input
,
scale
,
kv_dtype
)
torch
.
ops
.
_C_cache_ops
.
convert_fp8
(
output
,
input
,
scale
,
kv_dtype
)
def get_device_attribute(attribute: int, device: int) -> int:
    """Query a device attribute through the ``_C_cuda_utils`` op library.

    Args:
        attribute: Integer identifier of the attribute, passed through as-is.
        device: Integer device identifier, passed through as-is.

    Returns:
        Whatever the underlying ``get_device_attribute`` op returns.
    """
    query_attribute = torch.ops._C_cuda_utils.get_device_attribute
    value = query_attribute(attribute, device)
    return value
def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the result of the same-named ``_C_cuda_utils`` custom op.

    Args:
        device: Integer device identifier, passed through as-is.

    Returns:
        Whatever the underlying op returns for ``device``.
    """
    # ruff: noqa: E501
    cuda_utils = torch.ops._C_cuda_utils
    value = cuda_utils.get_max_shared_memory_per_block_device_attribute(device)
    return value
# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Call the ``init_custom_ar`` op of the ``_C_custom_ar`` op library.

    All six arguments are forwarded positionally and unchanged.

    Returns:
        The integer produced by the op. Sibling wrappers in this module take
        an ``fa: int`` argument; presumably that is this value -- confirm
        with the callers.
    """
    custom_ar = torch.ops._C_custom_ar
    handle = custom_ar.init_custom_ar(meta, rank_data, handles, offsets, rank,
                                      full_nvlink)
    return handle
def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Return the verdict of the ``_C_custom_ar.should_custom_ar`` op.

    The decision logic lives entirely in the compiled op; this wrapper only
    forwards ``inp``, ``max_size``, ``world_size`` and ``full_nvlink``
    positionally and hands back the op's result.
    """
    decide = torch.ops._C_custom_ar.should_custom_ar
    verdict = decide(inp, max_size, world_size, full_nvlink)
    return verdict
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Invoke the ``all_reduce_reg`` op of the ``_C_custom_ar`` op library.

    ``fa``, ``inp`` and ``out`` are forwarded positionally and unchanged.
    This wrapper returns nothing.
    """
    reduce_registered = torch.ops._C_custom_ar.all_reduce_reg
    reduce_registered(fa, inp, out)
def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """Invoke the ``all_reduce_unreg`` op of the ``_C_custom_ar`` op library.

    ``fa``, ``inp``, ``reg_buffer`` and ``out`` are forwarded positionally
    and unchanged. This wrapper returns nothing.
    """
    reduce_unregistered = torch.ops._C_custom_ar.all_reduce_unreg
    reduce_unregistered(fa, inp, reg_buffer, out)
#TODO: cuda_utils, custom_ar
def dispose(fa: int) -> None:
    """Invoke the ``dispose`` op of the ``_C_custom_ar`` op library on ``fa``.

    This wrapper returns nothing.
    """
    release = torch.ops._C_custom_ar.dispose
    release(fa)
def meta_size() -> int:
    """Return the value reported by the ``_C_custom_ar.meta_size`` op."""
    size = torch.ops._C_custom_ar.meta_size()
    return size
def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Invoke the ``register_buffer`` op of the ``_C_custom_ar`` op library.

    ``fa``, ``t``, ``handles`` and ``offsets`` are forwarded positionally
    and unchanged.
    """
    register = torch.ops._C_custom_ar.register_buffer
    # The op's result is passed straight back to the caller even though the
    # annotation is ``None``; kept as-is so the behaviour does not change.
    outcome = register(fa, t, handles, offsets)
    return outcome
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return the result of the ``_C_custom_ar.get_graph_buffer_ipc_meta`` op.

    Args:
        fa: Integer passed through to the op unchanged.

    Returns:
        The op's result, annotated here as a pair of a string list and an
        integer list.
    """
    fetch_ipc_meta = torch.ops._C_custom_ar.get_graph_buffer_ipc_meta
    ipc_meta = fetch_ipc_meta(fa)
    return ipc_meta
def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Invoke the ``register_graph_buffers`` op of ``_C_custom_ar``.

    ``fa``, ``handles`` and ``offsets`` are forwarded positionally and
    unchanged. This wrapper returns nothing.
    """
    register_all = torch.ops._C_custom_ar.register_graph_buffers
    register_all(fa, handles, offsets)
# punica
def dispatch_bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
) -> None:
    """Invoke the ``dispatch_bgmv`` op of the ``_punica_C`` op library.

    All six arguments are forwarded positionally and unchanged, in the
    order they are declared here. This wrapper returns nothing.
    """
    punica_bgmv = torch.ops._punica_C.dispatch_bgmv
    punica_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def dispatch_bgmv_low_level(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
    h_in: int,
    h_out: int,
    y_offset: int,
) -> None:
    """Invoke the ``dispatch_bgmv_low_level`` op of the ``_punica_C`` library.

    Takes the same six leading arguments as ``dispatch_bgmv`` plus
    ``h_in``, ``h_out`` and ``y_offset``. All nine are forwarded
    positionally and unchanged. This wrapper returns nothing.
    """
    punica_bgmv_low_level = torch.ops._punica_C.dispatch_bgmv_low_level
    forwarded = (y, x, w_t_all, indicies, layer_idx, scale, h_in, h_out,
                 y_offset)
    punica_bgmv_low_level(*forwarded)
vllm/attention/backends/flash_attn.py
View file @
f48954a4
...
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
...
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm
._C
import
cache_
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
AttentionMetadata
)
...
@@ -47,11 +47,11 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -47,11 +47,11 @@ class FlashAttentionBackend(AttentionBackend):
)
->
None
:
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
cache_
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
cache_
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
...
@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
)
->
None
:
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
cache_
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
dataclass
@
dataclass
...
@@ -285,7 +285,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -285,7 +285,7 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
# not cached. This happens during the initial memory profiling run.
cache_
ops
.
reshape_and_cache_flash
(
ops
.
reshape_and_cache_flash
(
key
,
key
,
value
,
value
,
key_cache
,
key_cache
,
...
@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# normal attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
out
=
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
out
=
output
[:
num_prefill_tokens
],
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
block_table
=
prefill_meta
.
block_tables
,
out
=
output
[:
num_prefill_tokens
],
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcache
(
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
decode_query
.
unsqueeze
(
1
),
key_cache
,
key_cache
,
value_cache
,
value_cache
,
...
@@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
).
squeeze
(
1
)
out
=
output
[
num_prefill_tokens
:].
unsqueeze
(
1
),
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
Prev
1
…
4
5
6
7
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment