Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
661 additions
and
187 deletions
+661
-187
tests/spec_decode/e2e/test_eagle_correctness.py
tests/spec_decode/e2e/test_eagle_correctness.py
+58
-0
tests/spec_decode/e2e/test_logprobs.py
tests/spec_decode/e2e/test_logprobs.py
+54
-41
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+60
-1
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+56
-1
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+59
-0
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+3
-9
tests/test_cache_block_hashing.py
tests/test_cache_block_hashing.py
+1
-4
tests/test_embedded_commit.py
tests/test_embedded_commit.py
+4
-3
tests/test_logger.py
tests/test_logger.py
+2
-2
tests/test_logits_processor.py
tests/test_logits_processor.py
+2
-6
tests/test_sequence.py
tests/test_sequence.py
+2
-5
tests/tpu/test_custom_dispatcher.py
tests/tpu/test_custom_dispatcher.py
+7
-0
tests/utils.py
tests/utils.py
+3
-2
tests/weight_loading/models-large.txt
tests/weight_loading/models-large.txt
+3
-1
tests/weight_loading/run_model_weight_loading_test.sh
tests/weight_loading/run_model_weight_loading_test.sh
+0
-0
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+214
-48
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+6
-14
use_existing_torch.py
use_existing_torch.py
+18
-0
vllm/__init__.py
vllm/__init__.py
+3
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+106
-48
No files found.
tests/spec_decode/e2e/test_eagle_correctness.py
View file @
539aa992
...
@@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...
@@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
batch_size
,
output_len
,
seed
)
batch_size
,
output_len
,
seed
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs_during_spec_decoding"
:
False
,
},
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs_during_spec_decoding"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
128
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_eagle_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
output_len
,
seed
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
tests/spec_decode/e2e/test_logprobs.py
View file @
539aa992
...
@@ -4,7 +4,7 @@ import pytest
...
@@ -4,7 +4,7 @@ import pytest
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
.conftest
import
run_
logprob
_correctness_test
from
.conftest
import
run_
equality
_correctness_test
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -25,6 +25,10 @@ from .conftest import run_logprob_correctness_test
...
@@ -25,6 +25,10 @@ from .conftest import run_logprob_correctness_test
"speculative_model"
:
"JackFram/llama-160m"
,
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
"num_speculative_tokens"
:
3
,
"disable_logprobs_during_spec_decoding"
:
False
,
"disable_logprobs_during_spec_decoding"
:
False
,
},
{
"speculative_model"
:
"JackFram/llama-160m"
,
"num_speculative_tokens"
:
3
,
"disable_logprobs_during_spec_decoding"
:
True
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -41,16 +45,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
...
@@ -41,16 +45,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
seed
:
int
,
logprobs
:
int
):
seed
:
int
,
logprobs
:
int
):
"""Verify output logprobs are equal with and without speculative decoding.
"""Verify output logprobs are equal with and without speculative decoding.
"""
"""
run_logprob_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
batch_size
,
output_len
,
output_len
,
seed
,
seed
,
temperature
=
0.0
,
temperature
=
0.0
,
logprobs
=
logprobs
)
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
...
@@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Veriy logprob greedy equality with different speculation lens.
"""Veriy logprob greedy equality with different speculation lens.
"""
"""
run_logprob_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
batch_size
,
output_len
,
output_len
,
seed
,
seed
,
temperature
=
0.0
,
temperature
=
0.0
,
logprobs
=
logprobs
)
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
...
@@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed
:
int
,
logprobs
:
int
):
seed
:
int
,
logprobs
:
int
):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
"""
run_logprob_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
batch_size
,
output_len
,
output_len
,
seed
,
seed
,
temperature
=
0.0
,
temperature
=
0.0
,
logprobs
=
logprobs
)
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
...
@@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled.
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
Token choices should match with the base model.
"""
"""
run_logprob_correctness_test
(
vllm_runner
,
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
batch_size
,
output_len
,
output_len
,
seed
,
seed
,
temperature
=
0.0
,
temperature
=
0.0
,
logprobs
=
logprobs
)
logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
tests/spec_decode/e2e/test_medusa_correctness.py
View file @
539aa992
...
@@ -31,7 +31,7 @@ MAIN_MODEL = "JackFram/llama-68m"
...
@@ -31,7 +31,7 @@ MAIN_MODEL = "JackFram/llama-68m"
# speculative model
# speculative model
SPEC_MODEL
=
"abhigoyal/vllm-medusa-llama-68m-random"
SPEC_MODEL
=
"abhigoyal/vllm-medusa-llama-68m-random"
# max
.
number of speculative tokens: this corresponds to
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS
=
5
MAX_SPEC_TOKENS
=
5
...
@@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...
@@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
temperature
=
0.0
)
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs_during_spec_decoding"
:
False
,
},
{
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"disable_logprobs_during_spec_decoding"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
8
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_medusa_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify greedy equality with different batch size."""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
tests/spec_decode/e2e/test_mlp_correctness.py
View file @
539aa992
...
@@ -16,7 +16,7 @@ However, we still need to verify below scenario could be passed:
...
@@ -16,7 +16,7 @@ However, we still need to verify below scenario could be passed:
* Test greedy equality under various number of speculative tokens.
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, MLPSpeculator would not break the
With those tests, we can say at least, MLPSpeculator would not break the
correctess for the target model outputs.
correct
n
ess for the target model outputs.
"""
"""
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...
@@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
temperature
=
0.0
)
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
"disable_logprobs_during_spec_decoding"
:
False
,
},
{
"speculative_model"
:
SPEC_MODEL
,
"disable_logprobs_during_spec_decoding"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_mlp_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify greedy equality with different batch size."""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
539aa992
...
@@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...
@@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
temperature
=
0.0
)
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model_name"
:
"JackFram/llama-68m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"disable_logprobs_during_spec_decoding"
:
False
,
},
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"disable_logprobs_during_spec_decoding"
:
True
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
8
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"logprobs"
,
[
1
,
6
])
def
test_ngram_e2e_greedy_logprobs
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
,
logprobs
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
,
logprobs
=
logprobs
,
prompt_logprobs
=
logprobs
,
disable_logprobs
=
test_llm_kwargs
[
'disable_logprobs_during_spec_decoding'
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
tests/spec_decode/utils.py
View file @
539aa992
from
array
import
array
from
itertools
import
count
from
itertools
import
count
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
...
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
CompletionSequenceGroupOutput
,
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
...
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
...
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
request_id
=
str
(
i
),
request_id
=
str
(
i
),
is_prompt
=
len
(
cont_token_ids
)
==
0
,
is_prompt
=
len
(
cont_token_ids
)
==
0
,
seq_data
=
{
seq_data
=
{
i
:
i
:
SequenceData
.
from_seqs
(
prompt_token_ids
[:],
SequenceData
(
cont_token_ids
[:]),
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
prompt_token_ids
[:]),
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
cont_token_ids
[:]),
),
},
},
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
block_tables
=
{
i
:
block_allocations
[
i
][:]},
block_tables
=
{
i
:
block_allocations
[
i
][:]},
...
...
tests/test_cache_block_hashing.py
View file @
539aa992
...
@@ -66,8 +66,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
...
@@ -66,8 +66,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
hashes
.
append
([])
hashes
.
append
([])
prompts
=
[
prefix
+
prompt
for
prompt
in
sample_prompts
]
prompts
=
[
prefix
+
prompt
for
prompt
in
sample_prompts
]
seq_id
=
0
for
seq_id
,
prompt
in
enumerate
(
prompts
):
for
prompt
in
prompts
:
hashes
[
-
1
].
append
([])
hashes
[
-
1
].
append
([])
prompt_token_ids
=
tokenizer
.
encode
(
prompt
)
prompt_token_ids
=
tokenizer
.
encode
(
prompt
)
seq
=
Sequence
(
seq_id
,
seq
=
Sequence
(
seq_id
,
...
@@ -83,8 +82,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
...
@@ -83,8 +82,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
for
idx
in
range
(
num_blocks
):
for
idx
in
range
(
num_blocks
):
hashes
[
-
1
][
-
1
].
append
(
seq
.
hash_of_block
(
idx
))
hashes
[
-
1
][
-
1
].
append
(
seq
.
hash_of_block
(
idx
))
seq_id
+=
1
# Check that hashes made with two prefixes with different first blocks are
# Check that hashes made with two prefixes with different first blocks are
# different everywhere.
# different everywhere.
for
hash0
,
hash1
in
zip
(
flatten_2d
(
hashes
[
0
]),
flatten_2d
(
hashes
[
1
])):
for
hash0
,
hash1
in
zip
(
flatten_2d
(
hashes
[
0
]),
flatten_2d
(
hashes
[
1
])):
...
...
tests/test_embedded_commit.py
View file @
539aa992
...
@@ -2,6 +2,7 @@ import vllm
...
@@ -2,6 +2,7 @@ import vllm
def
test_embedded_commit_defined
():
def
test_embedded_commit_defined
():
assert
vllm
.
__commit__
!=
"COMMIT_HASH_PLACEHOLDER"
assert
hasattr
(
vllm
,
"__version__"
)
# 7 characters is the length of a short commit hash
assert
hasattr
(
vllm
,
"__version_tuple__"
)
assert
len
(
vllm
.
__commit__
)
>=
7
assert
vllm
.
__version__
!=
"dev"
assert
vllm
.
__version_tuple__
!=
(
0
,
0
,
"dev"
)
tests/test_logger.py
View file @
539aa992
...
@@ -111,7 +111,7 @@ def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist():
...
@@ -111,7 +111,7 @@ def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist():
configuration occurs."""
configuration occurs."""
with
pytest
.
raises
(
RuntimeError
)
as
ex_info
:
with
pytest
.
raises
(
RuntimeError
)
as
ex_info
:
_configure_vllm_root_logger
()
_configure_vllm_root_logger
()
assert
ex_info
.
type
==
RuntimeError
assert
ex_info
.
type
==
RuntimeError
# noqa: E721
assert
"File does not exist"
in
str
(
ex_info
)
assert
"File does not exist"
in
str
(
ex_info
)
...
@@ -152,7 +152,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
...
@@ -152,7 +152,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
logging_config_file
.
name
):
logging_config_file
.
name
):
with
pytest
.
raises
(
ValueError
)
as
ex_info
:
with
pytest
.
raises
(
ValueError
)
as
ex_info
:
_configure_vllm_root_logger
()
_configure_vllm_root_logger
()
assert
ex_info
.
type
==
ValueError
assert
ex_info
.
type
==
ValueError
# noqa: E721
assert
"Invalid logging config. Expected Dict, got"
in
str
(
ex_info
)
assert
"Invalid logging config. Expected Dict, got"
in
str
(
ex_info
)
...
...
tests/test_logits_processor.py
View file @
539aa992
import
random
import
random
from
array
import
array
from
typing
import
Tuple
from
typing
import
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -9,8 +8,7 @@ import torch
...
@@ -9,8 +8,7 @@ import torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -71,9 +69,7 @@ def test_logits_processors(seed: int, device: str):
...
@@ -71,9 +69,7 @@ def test_logits_processors(seed: int, device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
SamplingParams
(
temperature
=
0
,
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
...
...
tests/test_sequence.py
View file @
539aa992
from
array
import
array
import
pytest
import
pytest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SequenceData
,
CompletionSequenceGroupOutput
,
SequenceData
,
SequenceOutput
)
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
from
.core.utils
import
create_dummy_prompt
...
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):
...
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):
def
test_sequence_data_prefill
():
def
test_sequence_data_prefill
():
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
,
4
])
)
seq_data
=
SequenceData
.
from_seqs
(
[
1
,
2
,
3
,
4
])
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_computed_tokens
()
==
0
assert
seq_data
.
get_num_computed_tokens
()
==
0
# advance by 2
# advance by 2
...
...
tests/tpu/test_custom_dispatcher.py
View file @
539aa992
import
os
from
..utils
import
compare_two_settings
from
..utils
import
compare_two_settings
# --enforce-eager on TPU causes graph compilation
# this times out default Health Check in the MQLLMEngine,
# so we set the timeout here to 30s
os
.
environ
[
"VLLM_RPC_TIMEOUT"
]
=
"30000"
def
test_custom_dispatcher
():
def
test_custom_dispatcher
():
compare_two_settings
(
"google/gemma-2b"
,
compare_two_settings
(
"google/gemma-2b"
,
...
...
tests/utils.py
View file @
539aa992
...
@@ -119,7 +119,7 @@ class RemoteOpenAIServer:
...
@@ -119,7 +119,7 @@ class RemoteOpenAIServer:
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
self
.
proc
.
terminate
()
self
.
proc
.
terminate
()
try
:
try
:
self
.
proc
.
wait
(
3
)
self
.
proc
.
wait
(
8
)
except
subprocess
.
TimeoutExpired
:
except
subprocess
.
TimeoutExpired
:
# force kill if needed
# force kill if needed
self
.
proc
.
kill
()
self
.
proc
.
kill
()
...
@@ -493,6 +493,7 @@ async def completions_with_server_args(
...
@@ -493,6 +493,7 @@ async def completions_with_server_args(
'''
'''
outputs
=
None
outputs
=
None
max_wait_seconds
=
240
*
3
# 240 is default
with
RemoteOpenAIServer
(
model_name
,
with
RemoteOpenAIServer
(
model_name
,
server_cli_args
,
server_cli_args
,
max_wait_seconds
=
max_wait_seconds
)
as
server
:
max_wait_seconds
=
max_wait_seconds
)
as
server
:
...
@@ -503,7 +504,7 @@ async def completions_with_server_args(
...
@@ -503,7 +504,7 @@ async def completions_with_server_args(
stream
=
False
,
stream
=
False
,
max_tokens
=
5
,
max_tokens
=
5
,
logprobs
=
num_logprobs
)
logprobs
=
num_logprobs
)
assert
outputs
is
not
None
assert
outputs
is
not
None
,
"Completion API call failed."
return
outputs
return
outputs
...
...
tests/weight_loading/models-large.txt
View file @
539aa992
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
\ No newline at end of file
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
tests/weight_loading/run_model_weight_loading_test.sh
100644 → 100755
View file @
539aa992
File mode changed from 100644 to 100755
tests/worker/test_encoder_decoder_model_runner.py
View file @
539aa992
from
array
import
array
import
itertools
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_cpu
,
make_tensor_with_pad
from
vllm.utils
import
is_cpu
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
_get_graph_batch_size
# CUDA graph scenarios to test
#
# Currently CUDA graph is not supported
ENFORCE_EAGER
=
[
True
]
BATCH_SIZES
=
[
1
,
4
,
16
,
64
,
256
]
BATCH_SIZES
=
[
1
,
4
,
16
,
64
,
256
]
...
@@ -40,8 +35,7 @@ def _create_model_runner(model: str, *args,
...
@@ -40,8 +35,7 @@ def _create_model_runner(model: str, *args,
reason
=
"CPU backend is currently "
reason
=
"CPU backend is currently "
"unsupported for encoder/ "
"unsupported for encoder/ "
"decoder models"
)
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
def
test_empty_seq_group
():
def
test_empty_seq_group
(
enforce_eager
,
):
"""Verify prepare prompt and decode returns empty output
"""Verify prepare prompt and decode returns empty output
for empty seq group list"""
for empty seq group list"""
...
@@ -52,7 +46,7 @@ def test_empty_seq_group(enforce_eager, ):
...
@@ -52,7 +46,7 @@ def test_empty_seq_group(enforce_eager, ):
max_num_batched_tokens
=
100000
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
True
,
)
)
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
model_input
=
model_runner
.
_prepare_model_input_tensors
(
model_input
=
model_runner
.
_prepare_model_input_tensors
(
...
@@ -85,11 +79,7 @@ def test_empty_seq_group(enforce_eager, ):
...
@@ -85,11 +79,7 @@ def test_empty_seq_group(enforce_eager, ):
"unsupported for encoder/ "
"unsupported for encoder/ "
"decoder models"
)
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
def
test_prepare_prompt
(
batch_size
):
def
test_prepare_prompt
(
batch_size
,
enforce_eager
,
):
'''
'''
Test the ability of the encoder/decoder model runner subclass to
Test the ability of the encoder/decoder model runner subclass to
produce prefill-phase model inputs & attention metadata.
produce prefill-phase model inputs & attention metadata.
...
@@ -115,7 +105,7 @@ def test_prepare_prompt(
...
@@ -115,7 +105,7 @@ def test_prepare_prompt(
max_num_batched_tokens
=
100000
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
True
,
)
)
seq_lens
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
...
@@ -127,12 +117,10 @@ def test_prepare_prompt(
...
@@ -127,12 +117,10 @@ def test_prepare_prompt(
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
range
(
seq_len
)))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
encoder_seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -281,11 +269,8 @@ def test_prepare_prompt(
...
@@ -281,11 +269,8 @@ def test_prepare_prompt(
"unsupported for encoder/ "
"unsupported for encoder/ "
"decoder models"
)
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
ENFORCE_EAGER
)
@
pytest
.
mark
.
parametrize
(
"multiple_seqs_per_seq_group"
,
[
True
,
False
])
def
test_prepare_decode
(
def
test_prepare_decode
(
batch_size
,
multiple_seqs_per_seq_group
):
batch_size
,
enforce_eager
,
):
'''
'''
Test the ability of the encoder/decoder model runner subclass to
Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata.
produce decode-phase model inputs & attention metadata.
...
@@ -300,6 +285,7 @@ def test_prepare_decode(
...
@@ -300,6 +285,7 @@ def test_prepare_decode(
Arguments:
Arguments:
* batch_size
* batch_size
* multiple_seqs_per_seq_group
* backend_name: The attention backend under test
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
'''
...
@@ -311,28 +297,33 @@ def test_prepare_decode(
...
@@ -311,28 +297,33 @@ def test_prepare_decode(
max_num_batched_tokens
=
100000
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
True
,
)
)
seq_lens
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
encoder_seq_lens
:
List
[
int
]
=
[]
encoder_seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
block_tables
=
{
0
:
[
1
],
1
:
[
3
]
}
if
multiple_seqs_per_seq_group
else
{
0
:
[
1
]
}
cross_block_table
=
[
2
]
cross_block_table
=
[
2
]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
is_prompt
=
False
,
seq_data
=
{
0
:
seq_data
},
seq_data
=
{
0
:
seq_data
,
1
:
seq_data
}
if
multiple_seqs_per_seq_group
else
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
block_tables
=
block_tables
,
encoder_seq_data
=
encoder_seq_data
,
encoder_seq_data
=
encoder_seq_data
,
...
@@ -340,6 +331,10 @@ def test_prepare_decode(
...
@@ -340,6 +331,10 @@ def test_prepare_decode(
)
)
assert
seq_group_metadata
.
token_chunk_size
==
1
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_lens
.
extend
(
[
seq_len
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))])
encoder_seq_lens
.
extend
(
[
encoder_seq_len
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))])
# Build
# Build
# * Decoder model inputs
# * Decoder model inputs
...
@@ -410,25 +405,31 @@ def test_prepare_decode(
...
@@ -410,25 +405,31 @@ def test_prepare_decode(
# Verify block tables are correct for prompts
# Verify block tables are correct for prompts
# - Decoder self-attention
# - Decoder self-attention
expected
=
torch
.
tensor
(
flattened_block_tables
=
[
[
block_tables
[
0
]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
block_table
for
block_table
in
block_tables
.
values
()
dtype
=
torch
.
int32
,
]
device
=
model_runner
.
device
)
expected
=
torch
.
tensor
(
flattened_block_tables
*
len
(
seq_group_metadata_list
),
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
assert
torch
.
equal
(
assert
torch
.
equal
(
attn_metadata
.
block_tables
,
attn_metadata
.
block_tables
,
expected
,
expected
,
)
)
# - Encoder/decoder cross-attention
# - Encoder/decoder cross-attention
expected
=
torch
.
tensor
(
expected
=
torch
.
tensor
([
[
cross_block_table
for
_
in
range
(
len
(
seq_group_metadata_list
))],
cross_block_table
for
seq_group_metadata
in
seq_group_metadata_list
dtype
=
torch
.
int32
,
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))
device
=
model_runner
.
device
)
],
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
assert
torch
.
equal
(
assert
torch
.
equal
(
attn_metadata
.
cross_block_tables
,
attn_metadata
.
cross_block_tables
,
expected
,
expected
,
)
)
# Cuda graph should is currently not supported for encoder/decoer.
# Model runner's CUDAGraph setting should be propagated to attention
# metadata.
assert
attn_metadata
.
use_cuda_graph
is
False
assert
attn_metadata
.
use_cuda_graph
is
False
# Verify the lengths of input tokens & positions
# Verify the lengths of input tokens & positions
...
@@ -464,8 +465,7 @@ def test_prepare_decode(
...
@@ -464,8 +465,7 @@ def test_prepare_decode(
# each sequence) in the decode phase
# each sequence) in the decode phase
expected_selected_token_indices
=
[]
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
selected_token_start_idx
,
seq_len
in
enumerate
(
seq_lens
):
for
seq_len
in
seq_lens
:
# Compute the index offset of the final token in each
# Compute the index offset of the final token in each
# sequence's decoded outputs; since a single token is
# sequence's decoded outputs; since a single token is
# decoded per iteration per sequence, then the length
# decoded per iteration per sequence, then the length
...
@@ -474,7 +474,6 @@ def test_prepare_decode(
...
@@ -474,7 +474,6 @@ def test_prepare_decode(
# generated tokens is 0 (i.e. the expected sampling index
# generated tokens is 0 (i.e. the expected sampling index
# for a given sequence is just `selected_token_start_idx`)
# for a given sequence is just `selected_token_start_idx`)
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
selected_token_start_idx
+=
1
sampling_metadata
=
model_input
.
sampling_metadata
sampling_metadata
=
model_input
.
sampling_metadata
actual
=
sampling_metadata
.
selected_token_indices
actual
=
sampling_metadata
.
selected_token_indices
...
@@ -484,3 +483,170 @@ def test_prepare_decode(
...
@@ -484,3 +483,170 @@ def test_prepare_decode(
dtype
=
actual
.
dtype
,
dtype
=
actual
.
dtype
,
)
)
assert
torch
.
equal
(
actual
,
expected
)
assert
torch
.
equal
(
actual
,
expected
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
@
pytest
.
mark
.
parametrize
(
"multiple_seqs_per_seq_group"
,
[
True
,
False
])
def
test_prepare_decode_cuda_graph
(
batch_size
,
multiple_seqs_per_seq_group
):
"""
Tests that for encoder-decoder models with CUDA Graph capture and replay
enabled, the tensors used during the decode phase are correctly padded
for varying input batch sizes.
"""
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
False
,
)
block_tables
=
{
0
:
[
1
],
1
:
[
3
]
}
if
multiple_seqs_per_seq_group
else
{
0
:
[
1
]
}
seq_lens
:
List
[
int
]
=
[]
encoder_seq_lens
:
List
[
int
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
cross_block_table
=
[
2
]
expanded_batch_size
=
0
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
seq_data
=
{
0
:
seq_data
,
1
:
seq_data
}
if
multiple_seqs_per_seq_group
else
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
)
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_lens
.
extend
(
[
seq_len
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))])
encoder_seq_lens
.
extend
(
[
encoder_seq_len
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))])
expanded_batch_size
=
expanded_batch_size
+
len
(
seq_group_metadata
.
seq_data
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
model_input
=
model_runner
.
prepare_model_input
(
seq_group_metadata_list
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
return_seq_lens
=
model_input
.
seq_lens
slot_mapping
=
attn_metadata
.
slot_mapping
encoder_input_tokens
=
model_input
.
encoder_input_tokens
encoder_input_positions
=
model_input
.
encoder_input_positions
cross_slot_mapping
=
attn_metadata
.
cross_slot_mapping
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size
=
_get_graph_batch_size
(
expanded_batch_size
)
cuda_graph_pad_size
=
graph_batch_size
-
expanded_batch_size
padded_seq_lens
=
seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
padded_encoder_seq_lens
=
encoder_seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
assert
return_seq_lens
==
padded_seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
cross_slot_mapping
)
==
len
(
encoder_input_tokens
)
# Verify attention metadata
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
==
0
assert
attn_metadata
.
num_decode_tokens
>
0
assert
torch
.
equal
(
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
padded_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
padded_seq_lens
assert
attn_metadata
.
max_prefill_seq_len
==
0
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
# - Encoder attention metadata
assert
attn_metadata
.
encoder_seq_lens
==
padded_encoder_seq_lens
assert
torch
.
equal
(
attn_metadata
.
encoder_seq_lens_tensor
,
torch
.
tensor
(
padded_encoder_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
max_encoder_seq_len
==
max
(
padded_encoder_seq_lens
)
assert
attn_metadata
.
num_encoder_tokens
==
sum
(
padded_encoder_seq_lens
)
# Verify block tables are correct for prompts
# - Decoder self-attention. Pad the block tables as expected.
flattened_block_tables
=
[
block_table
for
_
in
range
(
len
(
seq_group_metadata_list
))
for
block_table
in
block_tables
.
values
()
]
flattened_block_tables
.
extend
([[]
for
_
in
range
(
cuda_graph_pad_size
)])
expected
=
make_tensor_with_pad
(
flattened_block_tables
,
max_len
=
64
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
,
)
assert
torch
.
equal
(
attn_metadata
.
block_tables
,
expected
,
)
# - Encoder/decoder cross-attention. Pad the cross-attention block tables
# as expected.
expected
=
[
cross_block_table
for
seq_group_metadata
in
seq_group_metadata_list
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))
]
expected
.
extend
([[]
for
_
in
range
(
cuda_graph_pad_size
)])
expected
=
make_tensor_with_pad
(
expected
,
max_len
=
64
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
,
)
assert
torch
.
equal
(
attn_metadata
.
cross_block_tables
,
expected
,
)
# Model runner's CUDAGraph setting should be propagated to attention
# metadata.
assert
attn_metadata
.
use_cuda_graph
is
True
# Verify the lengths of input tokens & positions
# - Decoder
assert
len
(
input_tokens
)
==
len
(
padded_seq_lens
)
assert
len
(
input_positions
)
==
len
(
padded_seq_lens
)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
input_tokens
,
input_positions
,
)
# - Encoder
assert
len
(
encoder_input_tokens
)
==
0
assert
len
(
encoder_input_tokens
)
==
0
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
encoder_input_tokens
,
encoder_input_positions
,
)
tests/worker/test_model_runner.py
View file @
539aa992
from
array
import
array
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
...
@@ -8,8 +7,7 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
...
@@ -8,8 +7,7 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
...
@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_lens
.
append
(
context_len
)
context_lens
.
append
(
context_len
)
seq_data
=
SequenceData
(
seq_data
=
SequenceData
.
from_seqs
(
range
(
context_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
)))
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
# Append one token ID since prefill is finished.
# Append one token ID since prefill is finished.
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
...
@@ -241,10 +237,8 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -241,10 +237,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify Sampling
# Verify Sampling
expected_selected_token_indices
=
[]
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
selected_token_start_idx
,
_
in
enumerate
(
context_lens
):
for
_
in
context_lens
:
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
selected_token_start_idx
+=
1
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_group_metadata_list
,
seq_lens
,
seq_lens
,
...
@@ -328,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -328,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -345,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -345,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
))
seq_data
=
SequenceData
.
from_seqs
(
range
(
context_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
...
...
use_existing_torch.py
0 → 100644
View file @
539aa992
import
glob
requires_files
=
glob
.
glob
(
'requirements*.txt'
)
requires_files
+=
[
"pyproject.toml"
]
for
file
in
requires_files
:
print
(
f
">>> cleaning
{
file
}
"
)
with
open
(
file
,
'r'
)
as
f
:
lines
=
f
.
readlines
()
if
"torch"
in
""
.
join
(
lines
).
lower
():
print
(
"removed:"
)
with
open
(
file
,
'w'
)
as
f
:
for
line
in
lines
:
if
'torch'
not
in
line
.
lower
():
f
.
write
(
line
)
else
:
print
(
line
.
strip
())
print
(
f
"<<< done cleaning
{
file
}
"
)
print
()
vllm/__init__.py
View file @
539aa992
...
@@ -11,11 +11,12 @@ from vllm.outputs import (CompletionOutput, EmbeddingOutput,
...
@@ -11,11 +11,12 @@ from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput
,
RequestOutput
)
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.version
import
__commit__
,
__version__
,
__dcu_version__
from
vllm.version
import
__version__
,
__version_tuple__
,
__dcu_version__
__all__
=
[
__all__
=
[
"__commit__"
,
"__version__"
,
"__version__"
,
"__version_tuple__"
,
"LLM"
,
"LLM"
,
"ModelRegistry"
,
"ModelRegistry"
,
"PromptInputs"
,
"PromptInputs"
,
...
...
vllm/_custom_ops.py
View file @
539aa992
...
@@ -22,8 +22,13 @@ if not current_platform.is_tpu():
...
@@ -22,8 +22,13 @@ if not current_platform.is_tpu():
except
ImportError
as
e
:
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
if
current_platform
.
is_rocm
():
import
vllm._rocm_C
# noqa: F401
supports_moe_ops
=
False
with
contextlib
.
suppress
(
ImportError
):
with
contextlib
.
suppress
(
ImportError
):
import
vllm._moe_C
# noqa: F401
import
vllm._moe_C
# noqa: F401
supports_moe_ops
=
True
def
hint_on_error
(
fn
):
def
hint_on_error
(
fn
):
...
@@ -204,8 +209,34 @@ def paged_attention_v2_opt(
...
@@ -204,8 +209,34 @@ def paged_attention_v2_opt(
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_rocm
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
torch
.
ops
.
_rocm_C
.
paged_attention
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
)
# pos encoding ops
# pos encoding ops
def
rotary_embedding
(
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -359,9 +390,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -359,9 +390,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# b_g_idx, use_exllama, bit)
# b_g_idx, use_exllama, bit)
# TODO: has to be a better way to do this
if
hasattr
(
torch
.
ops
.
_C
,
"gptq_gemm"
):
try
:
torch
.
ops
.
_C
.
gptq_gemm
# noqa B018
@
torch
.
library
.
register_fake
(
"_C::gptq_gemm"
)
@
torch
.
library
.
register_fake
(
"_C::gptq_gemm"
)
def
_gptq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
_gptq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
...
@@ -371,8 +400,6 @@ try:
...
@@ -371,8 +400,6 @@ try:
return
torch
.
empty
((
a
.
size
(
0
),
b_q_weight
.
size
(
1
)),
return
torch
.
empty
((
a
.
size
(
0
),
b_q_weight
.
size
(
1
)),
dtype
=
a
.
dtype
,
dtype
=
a
.
dtype
,
device
=
a
.
device
)
device
=
a
.
device
)
except
Exception
:
pass
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
...
@@ -399,9 +426,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -399,9 +426,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n
,
size_k
)
size_n
,
size_k
)
# TODO: has to be a better way to do this
if
hasattr
(
torch
.
ops
.
_C
,
"gptq_marlin_24_gemm"
):
try
:
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
# noqa B018
@
torch
.
library
.
register_fake
(
"_C::gptq_marlin_24_gemm"
)
@
torch
.
library
.
register_fake
(
"_C::gptq_marlin_24_gemm"
)
def
_gptq_marlin_24_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
_gptq_marlin_24_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
...
@@ -527,8 +552,8 @@ try:
...
@@ -527,8 +552,8 @@ try:
@
torch
.
library
.
register_fake
(
"_C::machete_gemm"
)
@
torch
.
library
.
register_fake
(
"_C::machete_gemm"
)
def
machete_gemm_fake
(
def
machete_gemm_fake
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b_q
:
torch
.
# Should be the tensor returned by machete_prepack_B
Tensor
,
# Should be the tensor returned by machete_prepack_B
b_q
:
torch
.
Tensor
,
b_type
:
ScalarType
,
b_type
:
ScalarType
,
b_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
b_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
b_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
b_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -545,7 +570,8 @@ try:
...
@@ -545,7 +570,8 @@ try:
@
torch
.
library
.
register_fake
(
"_C::machete_prepack_B"
)
@
torch
.
library
.
register_fake
(
"_C::machete_prepack_B"
)
def
machete_prepack_B_fake
(
b_q_weight
:
torch
.
Tensor
,
def
machete_prepack_B_fake
(
b_q_weight
:
torch
.
Tensor
,
b_type
:
ScalarType
)
->
torch
.
Tensor
:
b_type
:
ScalarType
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b_q_weight
)
return
torch
.
empty_like
(
b_q_weight
,
memory_format
=
torch
.
contiguous_format
)
@
torch
.
library
.
register_fake
(
"_C::causal_conv1d_fwd"
)
@
torch
.
library
.
register_fake
(
"_C::causal_conv1d_fwd"
)
def
causal_conv1d_fwd_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
def
causal_conv1d_fwd_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
@@ -557,10 +583,10 @@ try:
...
@@ -557,10 +583,10 @@ try:
return
torch
.
empty_like
(
x
)
return
torch
.
empty_like
(
x
)
@
torch
.
library
.
register_fake
(
"_C::causal_conv1d_update"
)
@
torch
.
library
.
register_fake
(
"_C::causal_conv1d_update"
)
def
causal_conv1d_update_fake
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
def
causal_conv1d_update_fake
(
weight
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
silu_activation
:
bool
)
->
torch
.
Tensor
:
conv_state_indices
:
Optional
[
torch
.
Tensor
]
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
return
torch
.
empty_like
(
x
)
@
torch
.
library
.
register_fake
(
"_C::selective_scan_fwd"
)
@
torch
.
library
.
register_fake
(
"_C::selective_scan_fwd"
)
...
@@ -571,20 +597,11 @@ try:
...
@@ -571,20 +597,11 @@ try:
delta_softplus
:
bool
,
index_
:
Optional
[
torch
.
Tensor
],
delta_softplus
:
bool
,
index_
:
Optional
[
torch
.
Tensor
],
x
:
Optional
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
x
:
Optional
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
a
=
torch
.
empty_like
(
u
)
a
=
torch
.
empty_like
(
u
)
if
x
is
not
None
:
b
=
x
else
:
b
=
torch
.
empty
((
u
.
size
(
0
),
u
.
size
(
1
),
A
.
size
(
1
)),
dtype
=
u
.
dtype
,
device
=
u
.
device
)
if
z_
is
not
None
:
if
z_
is
not
None
:
c
=
torch
.
empty_like
(
z_
)
c
=
torch
.
empty_like
(
z_
)
return
[
a
,
b
,
c
]
return
[
a
,
c
]
else
:
else
:
return
[
a
,
b
]
return
[
a
]
except
Exception
:
pass
# cutlass
# cutlass
...
@@ -668,7 +685,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
...
@@ -668,7 +685,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
num_bits
:
int
)
->
torch
.
Tensor
:
num_bits
:
int
)
->
torch
.
Tensor
:
num_experts
=
b_q_weight
.
shape
[
0
]
num_experts
=
b_q_weight
.
shape
[
0
]
assert
size_k
%
16
==
0
assert
size_k
%
16
==
0
output
=
torch
.
empty
((
num_experts
,
size_k
//
16
,
size_n
*
2
),
output
=
torch
.
empty
((
num_experts
,
size_k
//
16
,
size_n
*
(
num_bits
//
2
)
),
device
=
b_q_weight
.
device
,
device
=
b_q_weight
.
device
,
dtype
=
b_q_weight
.
dtype
)
dtype
=
b_q_weight
.
dtype
)
for
e
in
range
(
num_experts
):
for
e
in
range
(
num_experts
):
...
@@ -732,6 +749,18 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
...
@@ -732,6 +749,18 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
b_type
)
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
b_type
)
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
@
torch
.
library
.
register_fake
(
"_C::permute_cols"
)
def
_permute_cols_fake
(
a
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
a
)
def
permute_cols
(
a
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
permute_cols
(
a
,
perm
)
# fp8
# fp8
# def scaled_fp8_quant(
# def scaled_fp8_quant(
# input: torch.Tensor,
# input: torch.Tensor,
...
@@ -793,32 +822,43 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
...
@@ -793,32 +822,43 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
# int8
# int8
def
scaled_int8_quant
(
def
scaled_int8_quant
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
azp
:
Optional
[
torch
.
Tensor
]
=
None
,
symmetric
:
bool
=
True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
"""
Quantize the input tensor to int8 and return the quantized tensor and scale.
Quantize the input tensor to int8 and return the quantized tensor and scale
, and maybe azp
.
Args:
Args:
input: The input tensor to be quantized to int8.
input: The input tensor to be quantized to int8.
scale: Optional scaling factor for the int8 quantization.
scale: Optional scaling factor for the int8 quantization.
When not provided, we invoke dynamic-per-token quantization.
When not provided, we invoke dynamic-per-token quantization.
azp: Optional zero-point for the int8 quantization.
Must be provided for asymmetric quantization if `scale` is provided.
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns:
Returns:
Tuple[
T
orch.Tensor,
T
orch.Tensor] : Output int8 tensor
and scales
.
Tuple[
t
orch.Tensor,
torch.Tensor, Optional[t
orch.Tensor]
]
: Output int8 tensor
, scales, and optionally azp
.
"""
"""
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
int8
)
output
=
torch
.
empty_like
(
input
,
dtype
=
torch
.
int8
)
if
scale
is
not
None
:
if
scale
is
not
None
:
# static-per-tensor quantization.
# static-per-tensor quantization.
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
output
,
input
,
scale
)
assert
symmetric
==
(
return
output
,
scale
azp
is
None
),
"azp must only be provided for asymmetric quantization."
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
output
,
input
,
scale
,
azp
)
return
output
,
scale
,
None
# dynamic-per-token quantization.
# dynamic-per-token quantization.
input_scales
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
input_scales
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
output
,
input
,
input_scales
)
input_azp
=
None
if
symmetric
else
torch
.
empty_like
(
input_scales
,
return
output
,
input_scales
dtype
=
torch
.
int32
)
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
(
output
,
input
,
input_scales
,
input_azp
)
return
output
,
input_scales
,
input_azp
# qqq ops
# qqq ops
...
@@ -866,11 +906,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
...
@@ -866,11 +906,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
silu_activation
)
silu_activation
)
def
causal_conv1d_update
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
def
causal_conv1d_update
(
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
x
:
torch
.
Tensor
,
silu_activation
:
bool
)
->
torch
.
Tensor
:
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
conv_state_indices
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
return
torch
.
ops
.
_C
.
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias_
,
silu_activation
)
silu_activation
,
conv_state_indices
)
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
...
@@ -901,6 +947,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
...
@@ -901,6 +947,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies
,
gating_output
)
token_expert_indicies
,
gating_output
)
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
@
torch
.
library
.
register_fake
(
"_moe_C::marlin_gemm_moe"
)
def
marlin_gemm_moe_fake
(
a
:
torch
.
Tensor
,
b_q_weights
:
torch
.
Tensor
,
sorted_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
num_experts
:
int
,
topk
:
int
,
moe_block_size
:
int
,
replicate_input
:
bool
,
apply_weights
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
topk
,
size_n
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
def
reshape_and_cache
(
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
@@ -968,12 +1032,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
...
@@ -968,12 +1032,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
offsets
,
rank
,
full_nvlink
)
offsets
,
rank
,
full_nvlink
)
def
should_custom_ar
(
inp
:
torch
.
Tensor
,
max_size
:
int
,
world_size
:
int
,
full_nvlink
:
bool
)
->
bool
:
return
torch
.
ops
.
_C_custom_ar
.
should_custom_ar
(
inp
,
max_size
,
world_size
,
full_nvlink
)
def
all_reduce_reg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
def
all_reduce_reg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_reg
(
fa
,
inp
,
out
)
torch
.
ops
.
_C_custom_ar
.
all_reduce_reg
(
fa
,
inp
,
out
)
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment