Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4851c202
Commit
4851c202
authored
Sep 13, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.1' into v0.6.1-dev
parents
9b902f9e
3fd2b0d2
Changes
203
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
594 additions
and
70 deletions
+594
-70
tests/async_engine/test_chat_template.py
tests/async_engine/test_chat_template.py
+3
-2
tests/compile/test_wrapper.py
tests/compile/test_wrapper.py
+2
-2
tests/conftest.py
tests/conftest.py
+61
-7
tests/distributed/test_multimodal_broadcast.py
tests/distributed/test_multimodal_broadcast.py
+4
-2
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+29
-17
tests/engine/test_multiproc_workers.py
tests/engine/test_multiproc_workers.py
+3
-3
tests/entrypoints/llm/test_generate_multiple_loras.py
tests/entrypoints/llm/test_generate_multiple_loras.py
+1
-1
tests/entrypoints/openai/test_run_batch.py
tests/entrypoints/openai/test_run_batch.py
+3
-1
tests/entrypoints/openai/test_serving_engine.py
tests/entrypoints/openai/test_serving_engine.py
+107
-0
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+19
-3
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+16
-0
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+22
-0
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+13
-0
tests/kernels/test_flashinfer.py
tests/kernels/test_flashinfer.py
+2
-1
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+15
-0
tests/kernels/test_layernorm.py
tests/kernels/test_layernorm.py
+8
-0
tests/kernels/test_machete_gemm.py
tests/kernels/test_machete_gemm.py
+7
-0
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+29
-26
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+217
-4
tests/kernels/utils.py
tests/kernels/utils.py
+33
-1
No files found.
tests/async_engine/test_chat_template.py
View file @
4851c202
import
pytest
from
vllm.entrypoints.chat_utils
import
apply_chat_template
,
load_chat_template
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
load_chat_template
)
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
...
...
@@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt
=
add_generation_prompt
)
# Call the function and get the result
result
=
apply_chat_template
(
result
=
apply_
hf_
chat_template
(
tokenizer
,
conversation
=
mock_request
.
messages
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
...
...
tests/compile/test_wrapper.py
View file @
4851c202
...
...
@@ -2,7 +2,7 @@ from typing import Optional
import
torch
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispa
c
ther
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispat
c
her
class
MyMod
(
torch
.
nn
.
Module
):
...
...
@@ -13,7 +13,7 @@ class MyMod(torch.nn.Module):
return
x
*
2
class
MyWrapper
(
TorchCompileWrapperWithCustomDispa
c
ther
):
class
MyWrapper
(
TorchCompileWrapperWithCustomDispat
c
her
):
def
__init__
(
self
,
model
):
self
.
model
=
model
...
...
tests/conftest.py
View file @
4851c202
...
...
@@ -21,6 +21,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.video
import
VideoAsset
from
vllm.config
import
TokenizerPoolConfig
from
vllm.connections
import
global_http_connection
from
vllm.distributed
import
(
destroy_distributed_environment
,
...
...
@@ -44,6 +45,7 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
PromptImageInput
=
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]
PromptAudioInput
=
Union
[
List
[
Tuple
[
np
.
ndarray
,
int
]],
List
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]]
PromptVideoInput
=
Union
[
List
[
np
.
ndarray
],
List
[
List
[
np
.
ndarray
]]]
def
_read_prompts
(
filename
:
str
)
->
List
[
str
]:
...
...
@@ -85,8 +87,35 @@ class _ImageAssets(_ImageAssetsBase):
return
[
prompts
[
"stop_sign"
],
prompts
[
"cherry_blossom"
]]
class
_VideoAssetPrompts
(
TypedDict
):
sample_demo_1
:
str
if
sys
.
version_info
<
(
3
,
9
):
# UserList cannot be subscripted
class
_VideoAssetsBase
(
UserList
):
pass
else
:
class
_VideoAssetsBase
(
UserList
[
VideoAsset
]):
pass
class
_VideoAssets
(
_VideoAssetsBase
):
def
__init__
(
self
)
->
None
:
super
().
__init__
([
VideoAsset
(
"sample_demo_1.mp4"
),
])
def
prompts
(
self
,
prompts
:
_VideoAssetPrompts
)
->
List
[
str
]:
return
[
prompts
[
"sample_demo_1"
]]
IMAGE_ASSETS
=
_ImageAssets
()
"""Singleton instance of :class:`_ImageAssets`."""
VIDEO_ASSETS
=
_VideoAssets
()
"""Singleton instance of :class:`_VideoAssets`."""
@
pytest
.
fixture
(
autouse
=
True
)
...
...
@@ -202,6 +231,11 @@ def image_assets() -> _ImageAssets:
return
IMAGE_ASSETS
@
pytest
.
fixture
(
scope
=
"session"
)
def
video_assets
()
->
_VideoAssets
:
return
VIDEO_ASSETS
_T
=
TypeVar
(
"_T"
,
nn
.
Module
,
torch
.
Tensor
,
BatchEncoding
,
BatchFeature
)
...
...
@@ -278,7 +312,8 @@ class HfRunner:
def
generate
(
self
,
prompts
:
List
[
str
],
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
List
[
np
.
ndarray
]]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
if
images
:
...
...
@@ -292,6 +327,8 @@ class HfRunner:
}
if
images
is
not
None
and
images
[
i
]
is
not
None
:
processor_kwargs
[
"images"
]
=
images
[
i
]
if
videos
is
not
None
and
videos
[
i
]
is
not
None
:
processor_kwargs
[
"videos"
]
=
videos
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
postprocess_inputs
(
inputs
)
...
...
@@ -314,7 +351,7 @@ class HfRunner:
self
,
prompts
:
List
[
str
],
max_tokens
:
int
,
images
:
Optional
[
List
[
Image
.
Image
]
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
outputs
=
self
.
generate
(
prompts
,
...
...
@@ -351,7 +388,8 @@ class HfRunner:
self
,
prompts
:
List
[
str
],
max_tokens
:
int
,
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
List
[
np
.
ndarray
]]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
List
[
torch
.
Tensor
]]:
all_logprobs
:
List
[
List
[
torch
.
Tensor
]]
=
[]
...
...
@@ -362,6 +400,8 @@ class HfRunner:
}
if
images
is
not
None
and
images
[
i
]
is
not
None
:
processor_kwargs
[
"images"
]
=
images
[
i
]
if
videos
is
not
None
and
videos
[
i
]
is
not
None
:
processor_kwargs
[
"videos"
]
=
videos
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
postprocess_inputs
(
inputs
)
...
...
@@ -433,8 +473,9 @@ class HfRunner:
prompts
:
List
[
str
],
max_tokens
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
audios
:
Optional
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
List
[
np
.
ndarray
]]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
List
[
Dict
[
int
,
float
]]]]:
all_logprobs
:
List
[
List
[
Dict
[
int
,
float
]]]
=
[]
...
...
@@ -454,6 +495,8 @@ class HfRunner:
processor_kwargs
[
"audio"
]
=
audio
processor_kwargs
[
"sampling_rate"
]
=
sr
if
videos
is
not
None
:
processor_kwargs
[
"videos"
]
=
videos
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
postprocess_inputs
(
inputs
)
...
...
@@ -634,12 +677,16 @@ class VllmRunner:
sampling_params
:
SamplingParams
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
assert
sampling_params
.
logprobs
is
not
None
if
images
is
not
None
:
assert
len
(
prompts
)
==
len
(
images
)
if
videos
is
not
None
:
assert
len
(
prompts
)
==
len
(
videos
)
inputs
=
[
TextPrompt
(
prompt
=
prompt
)
for
prompt
in
prompts
]
if
images
is
not
None
:
for
i
,
image
in
enumerate
(
images
):
...
...
@@ -649,6 +696,11 @@ class VllmRunner:
for
i
,
audio
in
enumerate
(
audios
):
inputs
[
i
][
"multi_modal_data"
]
=
{
"audio"
:
audio
}
if
videos
is
not
None
:
for
i
,
video
in
enumerate
(
videos
):
inputs
[
i
][
"multi_modal_data"
]
=
{
"video"
:
video
}
print
(
f
"[INPUTS!!!!]:
{
inputs
}
,
{
sampling_params
}
"
)
req_outputs
=
self
.
model
.
generate
(
inputs
,
sampling_params
=
sampling_params
)
return
self
.
_final_steps_generate_w_logprobs
(
req_outputs
)
...
...
@@ -671,7 +723,7 @@ class VllmRunner:
self
,
prompts
:
List
[
str
],
max_tokens
:
int
,
images
:
Optional
[
List
[
Image
.
Image
]
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
greedy_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
outputs
=
self
.
generate
(
prompts
,
greedy_params
,
images
=
images
)
...
...
@@ -685,6 +737,7 @@ class VllmRunner:
num_logprobs
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
...
...
@@ -694,7 +747,8 @@ class VllmRunner:
outputs
=
self
.
generate_w_logprobs
(
prompts
,
greedy_logprobs_params
,
images
=
images
,
audios
=
audios
)
audios
=
audios
,
videos
=
videos
)
return
[(
output_ids
,
output_str
,
output_logprobs
)
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
...
...
tests/distributed/test_multimodal_broadcast.py
View file @
4851c202
...
...
@@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
if
model
.
startswith
(
"llava-hf/llava-1.5"
):
from
..models.test_llava
import
models
,
run_test
elif
model
.
startswith
(
"llava-hf/llava-v1.6"
):
from
..models.test_llava_next
import
models
,
run_test
from
..models.test_llava_next
import
run_test
# type: ignore[no-redef]
from
..models.test_llava_next
import
models
elif
model
.
startswith
(
"facebook/chameleon"
):
from
..models.test_chameleon
import
models
,
run_test
from
..models.test_chameleon
import
run_test
# type: ignore[no-redef]
from
..models.test_chameleon
import
models
else
:
raise
NotImplementedError
(
f
"Unsupported model:
{
model
}
"
)
...
...
tests/distributed/test_pipeline_parallel.py
View file @
4851c202
...
...
@@ -18,23 +18,28 @@ logger = init_logger("test_pipeline_parallel")
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
@
pytest
.
mark
.
parametrize
((
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
@
pytest
.
mark
.
parametrize
(
(
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
"MODEL_NAME, DIST_BACKEND"
),
[
(
2
,
2
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
2
,
2
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
])
(
2
,
2
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
2
,
2
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-1B"
,
"ray"
),
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-2B"
,
"ray"
),
(
1
,
2
,
1
,
0
,
1
,
"OpenGVLab/InternVL2-4B"
,
"ray"
),
],
)
@
fork_new_process_for_each_test
def
test_compare_tp
(
TP_SIZE
,
PP_SIZE
,
EAGER_MODE
,
CHUNKED_PREFILL
,
MODEL_NAME
,
DIST_BACKEND
):
def
test_compare_tp
(
TP_SIZE
,
PP_SIZE
,
EAGER_MODE
,
CHUNKED_PREFILL
,
TRUST_REMOTE_CODE
,
MODEL_NAME
,
DIST_BACKEND
):
if
VLLM_MULTI_NODE
and
DIST_BACKEND
==
"mp"
:
pytest
.
skip
(
"Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend"
)
...
...
@@ -43,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"float16"
,
"--max-model-len"
,
"8192"
,
"--pipeline-parallel-size"
,
str
(
PP_SIZE
),
"--tensor-parallel-size"
,
...
...
@@ -59,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
tp_args
=
[
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"bfloat16"
,
"float16"
,
"--max-model-len"
,
"8192"
,
"--tensor-parallel-size"
,
str
(
max
(
TP_SIZE
,
2
)),
# We only use 2 GPUs in the CI.
"--distributed-executor-backend"
,
...
...
@@ -71,6 +80,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
if
EAGER_MODE
:
pp_args
.
append
(
"--enforce-eager"
)
tp_args
.
append
(
"--enforce-eager"
)
if
TRUST_REMOTE_CODE
:
pp_args
.
append
(
"--trust-remote-code"
)
tp_args
.
append
(
"--trust-remote-code"
)
pp_env
=
None
if
(
DIST_BACKEND
==
"ray"
and
TP_SIZE
==
2
and
PP_SIZE
==
2
and
CHUNKED_PREFILL
):
...
...
tests/engine/test_multiproc_workers.py
View file @
4851c202
...
...
@@ -83,7 +83,7 @@ def test_local_workers() -> None:
workers
[
3
].
process
.
kill
()
# Other workers should get shut down here
worker_monitor
.
join
(
2
)
worker_monitor
.
join
(
2
0
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
...
...
@@ -108,7 +108,7 @@ def test_local_workers_clean_shutdown() -> None:
# Clean shutdown
worker_monitor
.
close
()
worker_monitor
.
join
(
5
)
worker_monitor
.
join
(
20
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
...
...
@@ -161,7 +161,7 @@ async def test_local_workers_async() -> None:
workers
[
3
].
process
.
kill
()
# Other workers should get shut down here
worker_monitor
.
join
(
2
)
worker_monitor
.
join
(
2
0
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
...
...
tests/entrypoints/llm/test_generate_multiple_loras.py
View file @
4851c202
...
...
@@ -50,7 +50,7 @@ def zephyr_lora_files():
@
pytest
.
mark
.
skip_global_cleanup
def
test_multiple_lora_requests
(
llm
:
LLM
,
zephyr_lora_files
):
lora_request
=
[
LoRARequest
(
LORA_NAME
,
idx
+
1
,
zephyr_lora_files
)
LoRARequest
(
LORA_NAME
+
str
(
idx
)
,
idx
+
1
,
zephyr_lora_files
)
for
idx
in
range
(
len
(
PROMPTS
))
]
# Multiple SamplingParams should be matched with each prompt
...
...
tests/entrypoints/openai/test_run_batch.py
View file @
4851c202
...
...
@@ -8,7 +8,9 @@ from vllm.entrypoints.openai.protocol import BatchRequestOutput
INPUT_BATCH
=
"""{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-4", "method": "POST", "url": "/bad_url", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-5", "method": "POST", "url": "/v1/chat/completions", "body": {"stream": "True", "model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
INVALID_INPUT_BATCH
=
"""{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
...
...
tests/entrypoints/openai/test_serving_engine.py
0 → 100644
View file @
4851c202
from
http
import
HTTPStatus
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
LoadLoraAdapterRequest
,
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
MODEL_NAME
=
"meta-llama/Llama-2-7b"
LORA_LOADING_SUCCESS_MESSAGE
=
(
"Success: LoRA adapter '{lora_name}' added successfully."
)
LORA_UNLOADING_SUCCESS_MESSAGE
=
(
"Success: LoRA adapter '{lora_name}' removed successfully."
)
async
def
_async_serving_engine_init
():
mock_engine_client
=
MagicMock
(
spec
=
AsyncEngineClient
)
mock_model_config
=
MagicMock
(
spec
=
ModelConfig
)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config
.
max_model_len
=
2048
serving_engine
=
OpenAIServing
(
mock_engine_client
,
mock_model_config
,
served_model_names
=
[
MODEL_NAME
],
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
None
)
return
serving_engine
@
pytest
.
mark
.
asyncio
async
def
test_load_lora_adapter_success
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
LoadLoraAdapterRequest
(
lora_name
=
"adapter"
,
lora_path
=
"/path/to/adapter2"
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
response
==
LORA_LOADING_SUCCESS_MESSAGE
.
format
(
lora_name
=
'adapter'
)
assert
len
(
serving_engine
.
lora_requests
)
==
1
assert
serving_engine
.
lora_requests
[
0
].
lora_name
==
"adapter"
@
pytest
.
mark
.
asyncio
async
def
test_load_lora_adapter_missing_fields
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
LoadLoraAdapterRequest
(
lora_name
=
""
,
lora_path
=
""
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
isinstance
(
response
,
ErrorResponse
)
assert
response
.
type
==
"InvalidUserInput"
assert
response
.
code
==
HTTPStatus
.
BAD_REQUEST
@
pytest
.
mark
.
asyncio
async
def
test_load_lora_adapter_duplicate
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
LoadLoraAdapterRequest
(
lora_name
=
"adapter1"
,
lora_path
=
"/path/to/adapter1"
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
response
==
LORA_LOADING_SUCCESS_MESSAGE
.
format
(
lora_name
=
'adapter1'
)
assert
len
(
serving_engine
.
lora_requests
)
==
1
request
=
LoadLoraAdapterRequest
(
lora_name
=
"adapter1"
,
lora_path
=
"/path/to/adapter1"
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
isinstance
(
response
,
ErrorResponse
)
assert
response
.
type
==
"InvalidUserInput"
assert
response
.
code
==
HTTPStatus
.
BAD_REQUEST
assert
len
(
serving_engine
.
lora_requests
)
==
1
@
pytest
.
mark
.
asyncio
async
def
test_unload_lora_adapter_success
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
LoadLoraAdapterRequest
(
lora_name
=
"adapter1"
,
lora_path
=
"/path/to/adapter1"
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
len
(
serving_engine
.
lora_requests
)
==
1
request
=
UnloadLoraAdapterRequest
(
lora_name
=
"adapter1"
)
response
=
await
serving_engine
.
unload_lora_adapter
(
request
)
assert
response
==
LORA_UNLOADING_SUCCESS_MESSAGE
.
format
(
lora_name
=
'adapter1'
)
assert
len
(
serving_engine
.
lora_requests
)
==
0
@
pytest
.
mark
.
asyncio
async
def
test_unload_lora_adapter_missing_fields
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
UnloadLoraAdapterRequest
(
lora_name
=
""
,
lora_int_id
=
None
)
response
=
await
serving_engine
.
unload_lora_adapter
(
request
)
assert
isinstance
(
response
,
ErrorResponse
)
assert
response
.
type
==
"InvalidUserInput"
assert
response
.
code
==
HTTPStatus
.
BAD_REQUEST
@
pytest
.
mark
.
asyncio
async
def
test_unload_lora_adapter_not_found
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
UnloadLoraAdapterRequest
(
lora_name
=
"nonexistent_adapter"
)
response
=
await
serving_engine
.
unload_lora_adapter
(
request
)
assert
isinstance
(
response
,
ErrorResponse
)
assert
response
.
type
==
"InvalidUserInput"
assert
response
.
code
==
HTTPStatus
.
BAD_REQUEST
tests/kernels/test_activation.py
View file @
4851c202
...
...
@@ -3,8 +3,10 @@ from typing import Type
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.activation
import
(
FastGELU
,
GeluAndMul
,
NewGELU
,
SiluAndMul
)
NewGELU
,
QuickGELU
,
SiluAndMul
)
from
.allclose_default
import
get_default_atol
,
get_default_rtol
...
...
@@ -39,18 +41,28 @@ def test_act_and_mul(
x
=
torch
.
randn
(
num_tokens
,
2
*
d
,
dtype
=
dtype
)
if
activation
==
"silu"
:
layer
=
SiluAndMul
()
fn
=
torch
.
ops
.
_C
.
silu_and_mul
elif
activation
==
"gelu"
:
layer
=
GeluAndMul
(
approximate
=
"none"
)
fn
=
torch
.
ops
.
_C
.
gelu_and_mul
elif
activation
==
"gelu_tanh"
:
layer
=
GeluAndMul
(
approximate
=
"tanh"
)
fn
=
torch
.
ops
.
_C
.
gelu_tanh_and_mul
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
opcheck
(
fn
,
(
out
,
x
))
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
FastGELU
,
NewGELU
])
@
pytest
.
mark
.
parametrize
(
"activation"
,
[(
FastGELU
,
torch
.
ops
.
_C
.
gelu_fast
),
(
NewGELU
,
torch
.
ops
.
_C
.
gelu_new
),
(
QuickGELU
,
torch
.
ops
.
_C
.
gelu_quick
)])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"d"
,
D
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
...
@@ -70,10 +82,14 @@ def test_activation(
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
x
=
torch
.
randn
(
num_tokens
,
d
,
dtype
=
dtype
)
layer
=
activation
()
layer
=
activation
[
0
]()
fn
=
activation
[
1
]
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
get_default_atol
(
out
),
rtol
=
get_default_rtol
(
out
))
out
=
torch
.
empty_like
(
x
)
opcheck
(
fn
,
(
out
,
x
))
tests/kernels/test_attention.py
View file @
4851c202
...
...
@@ -6,6 +6,7 @@ import torch
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
...
...
@@ -199,6 +200,13 @@ def test_paged_attention(
k_scale
,
v_scale
,
)
opcheck
(
torch
.
ops
.
_C
.
paged_attention_v1
,
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
elif
version
==
"v2"
:
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
assert
PARTITION_SIZE
%
block_size
==
0
...
...
@@ -231,6 +239,14 @@ def test_paged_attention(
k_scale
,
v_scale
,
)
opcheck
(
torch
.
ops
.
_C
.
paged_attention_v2
,
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
else
:
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
...
...
tests/kernels/test_cache.py
View file @
4851c202
...
...
@@ -4,6 +4,7 @@ from typing import List, Tuple
import
pytest
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
vllm
import
_custom_ops
as
ops
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
...
...
@@ -88,6 +89,11 @@ def test_copy_blocks(
block_mapping_tensor
=
torch
.
tensor
(
block_mapping
,
dtype
=
torch
.
int64
,
device
=
device
).
view
(
-
1
,
2
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
(
key_caches
,
value_caches
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
# Run the reference implementation.
...
...
@@ -163,6 +169,10 @@ def test_reshape_and_cache(
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
...
...
@@ -270,6 +280,10 @@ def test_reshape_and_cache_flash(
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
...
...
@@ -367,6 +381,14 @@ def test_swap_blocks(
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
# Call the swap_blocks kernel.
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
...
...
tests/kernels/test_cutlass.py
View file @
4851c202
...
...
@@ -7,6 +7,7 @@ from typing import Optional, Type
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
...
...
@@ -108,6 +109,9 @@ def cutlass_int8_gemm_helper(m: int,
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
100
,
33
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
4096
,
8192
,
16384
,
24576
,
256
,
1024
])
...
...
@@ -341,6 +345,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
rtol
,
atol
=
atol
)
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
rtol
,
atol
=
atol
)
if
azp_per_token
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_adj_i32
,
azp_i32
,
func_bias
))
else
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_with_adj_i32
,
None
,
func_bias
))
# Test working with a subset of A and B
def
test_cutlass_subset
():
...
...
tests/kernels/test_flashinfer.py
View file @
4851c202
...
...
@@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
head_size
,
block_size
,
"NONE"
,
data_type
=
dtype
)
data_type
=
dtype
,
q_data_type
=
dtype
)
output
=
wrapper
.
forward
(
query
,
kv_cache_fp8
,
logits_soft_cap
=
soft_cap
,
...
...
tests/kernels/test_int8_quant.py
View file @
4851c202
...
...
@@ -2,6 +2,7 @@ import pytest
import
torch
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
scaled_int8_quant
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -12,6 +13,16 @@ SEEDS = [0]
SCALE
=
[
0.1
,
0.5
,
0.8
,
1.2
,
2.1
]
def
opcheck_int8_quant
(
output
,
input
,
scale
=
None
):
if
scale
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
))
else
:
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
(
output
,
input
,
scale
))
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
...
@@ -34,6 +45,8 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
# big atol to account for rounding errors
opcheck_int8_quant
(
ops_out
,
x
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
...
...
@@ -58,3 +71,5 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
# big atol to account for rounding errors
opcheck_int8_quant
(
out2
,
x
,
scale
)
tests/kernels/test_layernorm.py
View file @
4851c202
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -52,3 +53,10 @@ def test_rms_norm(
torch
.
testing
.
assert_close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
if
residual
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
fused_add_rms_norm
,
(
x
,
residual
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
else
:
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
tests/kernels/test_machete_gemm.py
View file @
4851c202
...
...
@@ -9,6 +9,7 @@ from typing import Optional, Tuple
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_rows
,
quantize_weights
)
...
...
@@ -76,6 +77,8 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q
=
w_q
.
t
().
contiguous
().
t
()
# convert to col major
w_q_machete
=
ops
.
machete_prepack_B
(
w_q
,
wtype
)
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
wtype
))
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
...
...
@@ -146,6 +149,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule
=
schedule
,
)
opcheck
(
torch
.
ops
.
_C
.
machete_gemm
,
(
a
,
w_q_machete
,
wtype
,
w_s
,
maybe_convert_zeropoints
(
w_zp
,
w_s
),
group_size
,
None
,
None
,
None
,
schedule
))
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
...
...
tests/kernels/test_marlin_gemm.py
View file @
4851c202
...
...
@@ -5,6 +5,7 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import
pytest
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
...
...
@@ -73,12 +74,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order
,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
# Filter act_order
if
act_order
:
if
group_size
==
-
1
:
...
...
@@ -112,6 +110,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
quant_type
.
size_bits
,
weight_perm
)
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_repack
,
(
q_w_gptq
,
sort_indices
,
size_k
,
size_n
,
quant_type
.
size_bits
))
# Run Marlin repack GPU kernel
marlin_q_w_2
=
ops
.
gptq_marlin_repack
(
q_w_gptq
,
...
...
@@ -137,12 +138,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
...
...
@@ -165,6 +163,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
marlin_q_w_1
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
quant_type
.
size_bits
,
weight_perm
)
opcheck
(
torch
.
ops
.
_C
.
awq_marlin_repack
,
(
q_w_awq
,
size_k
,
size_n
,
quant_type
.
size_bits
))
# Run Marlin repack GPU kernel
marlin_q_w_2
=
ops
.
awq_marlin_repack
(
q_w_awq
,
...
...
@@ -204,9 +205,6 @@ def test_gptq_marlin_gemm(
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
print
(
f
"groupsize =
{
group_size
}
"
)
if
act_order
:
if
group_size
==
-
1
:
return
...
...
@@ -224,6 +222,13 @@ def test_gptq_marlin_gemm(
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_fp32_reduce
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
marlin_q_w
,
...
...
@@ -245,7 +250,6 @@ def test_gptq_marlin_gemm(
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
...
...
@@ -265,9 +269,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
print
(
f
"groupsize =
{
group_size
}
"
)
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
...
...
@@ -279,6 +280,12 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
output_ref
=
torch
.
matmul
(
a_input
,
w_24_ref
)
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
,
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
workspace_24
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_marlin_24_gemm
(
a_input
,
marlin_24_q_w_comp
,
...
...
@@ -294,7 +301,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
...
...
@@ -321,9 +327,6 @@ def test_fp8_marlin_gemm(
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
print
(
f
"groupsize =
{
group_size
}
"
)
a_input
=
rand_data
((
size_m
,
size_k
),
dtype
=
dtype
)
b_weight
=
rand_data
((
size_k
,
size_n
),
dtype
=
dtype
)
...
...
@@ -353,6 +356,10 @@ def test_fp8_marlin_gemm(
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
opcheck
(
torch
.
ops
.
_C
.
fp8_marlin_gemm
,
(
a_input
,
marlin_qweight
,
marlin_scales
,
workspace
.
scratch
,
num_bits
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]))
output
=
ops
.
fp8_marlin_gemm
(
a
=
a_input
,
b_q_weight
=
marlin_qweight
,
...
...
@@ -368,7 +375,6 @@ def test_fp8_marlin_gemm(
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
...
...
@@ -396,9 +402,6 @@ def test_awq_marlin_gemm(
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
print
(
f
"groupsize =
{
group_size
}
"
)
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
...
...
@@ -434,7 +437,6 @@ def test_awq_marlin_gemm(
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
...
...
@@ -460,9 +462,6 @@ def test_marlin_qqq_gemm(
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
print
(
f
"MNK =
{
size_m
}
{
size_n
}
{
size_k
}
"
)
print
(
f
"groupsize =
{
group_size
}
"
)
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
...
...
@@ -479,6 +478,11 @@ def test_marlin_qqq_gemm(
workspace
=
MarlinWorkspace
(
size_n
,
MARLIN_QQQ_MIN_THREAD_N
,
MARLIN_QQQ_MAX_PARALLEL
)
opcheck
(
torch
.
ops
.
_C
.
marlin_qqq_gemm
,
(
q_a
,
marlin_qqq_q_w
,
s_a
,
marlin_qqq_s_channel
,
marlin_qqq_s_group
,
workspace
.
scratch
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]))
output
=
ops
.
marlin_qqq_gemm
(
q_a
,
marlin_qqq_q_w
,
...
...
@@ -495,6 +499,5 @@ def test_marlin_qqq_gemm(
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
tests/kernels/test_moe.py
View file @
4851c202
...
...
@@ -2,6 +2,8 @@
Run `pytest tests/kernels/test_moe.py`.
"""
from
typing
import
List
import
pytest
import
torch
from
transformers
import
MixtralConfig
...
...
@@ -9,7 +11,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
marlin_quantize
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.scalar_type
import
scalar_types
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
...
...
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
def torch_moe_single(a, w, score, topk):
    """Pure-PyTorch reference for a single expert matmul with top-k routing.

    Each row of ``a`` is replicated ``topk`` times, every replica is
    multiplied by the weight of one of the row's top-k experts (chosen from
    the softmax of ``score``), and the ``topk`` partial results of a row are
    summed. The routing weights themselves are not applied to the outputs.

    Args:
        a: 2-D activations; its second dimension is contracted against the
            last dimension of each ``w[i]``.
        w: Stacked expert weights indexed by expert along dim 0.
        score: Router scores; the softmax is taken over the last dimension.
        topk: Number of experts selected for each row of ``a``.

    Returns:
        A tensor with ``a.shape[0]`` rows and ``w.shape[1]`` columns.
    """
    num_rows, hidden = a.shape
    num_experts, out_features = w.shape[0], w.shape[1]

    # One copy of every input row per selected expert, laid out row-major so
    # it lines up with the flattened expert ids below.
    replicated = a.view(num_rows, -1, hidden).repeat(1, topk, 1).reshape(
        -1, hidden)
    result = torch.zeros(num_rows * topk,
                         out_features,
                         dtype=replicated.dtype,
                         device=replicated.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    expert_ids = torch.topk(probs, topk).indices.view(-1)

    for expert in range(num_experts):
        selected = expert_ids == expert
        if not selected.sum():
            # No replica was routed to this expert.
            continue
        result[selected] = replicated[selected] @ w[expert].transpose(0, 1)

    return result.view(num_rows, -1, out_features).sum(dim=1)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1024
*
128
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
...
...
@@ -43,11 +65,11 @@ def test_fused_moe(
topk
:
int
,
dtype
:
torch
.
dtype
,
):
a
=
torch
.
randn
((
m
,
k
),
device
=
'
cuda
'
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
'
cuda
'
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
'
cuda
'
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"
cuda
"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"
cuda
"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"
cuda
"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
'
cuda
'
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"
cuda
"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
...
...
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
def stack_and_dev(tensors: List[torch.Tensor]):
    """Stack ``tensors`` along a new leading dimension.

    The stacked result is moved to the device of the first tensor in the
    list.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors)
    return stacked.to(target_device)
def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite the name, this is not a maximum: it is the mean of
    ``|output - output_ref|`` divided by the mean of ``|output_ref|``.
    """
    mean_abs_error = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_error / mean_abs_ref
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
):
    """Compare ``fused_marlin_moe`` with ``fused_moe`` on reference weights.

    Random fp16 expert weights are quantized per expert with
    ``marlin_quantize``. ``fused_moe`` is run on the reference weights that
    ``marlin_quantize`` returns, ``fused_marlin_moe`` on the quantized
    weights, and the two outputs must agree within a relative error of 4e-2
    as measured by ``compute_max_diff``.
    """
    torch.manual_seed(7)

    if topk > e:
        return

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size in (k, n):
            return

    quant_type = scalar_types.uint4b8
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    # Only expert 0 is overwritten with an identity-like matrix. This used to
    # be wrapped in a loop over all experts that never used its index, so a
    # single assignment produces exactly the same weights.
    # NOTE(review): if the intent was to overwrite every expert (``w2[i]``),
    # that would change the test data -- confirm before changing.
    w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

    def quantize_experts(weights, perm_len):
        """Quantize each expert of ``weights`` and stack the per-expert results.

        Returns ``(w_ref, qweight, scales, g_idx, sort_indices)``. A fresh
        ``torch.randperm(perm_len)`` is drawn for every expert immediately
        before it is quantized, which keeps the random-number order of the
        original two hand-written loops.
        """
        w_ref_l = []
        qweight_l = []
        scales_l = []
        g_idx_l = []
        sort_indices_l = []

        for i in range(weights.shape[0]):
            test_perm = torch.randperm(perm_len)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                weights[i].transpose(1, 0), quant_type, group_size, act_order,
                test_perm)
            w_ref_l.append(w_ref)
            qweight_l.append(qweight)
            scales_l.append(scales)
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)

        return (
            stack_and_dev(w_ref_l),
            stack_and_dev(qweight_l).contiguous(),
            stack_and_dev(scales_l),
            stack_and_dev(g_idx_l),
            stack_and_dev(sort_indices_l),
        )

    # w1 must be fully quantized before w2 so the randperm draws happen in
    # the same order as before.
    w_ref1, qweight1, scales1, g_idx1, sort_indices1 = quantize_experts(w1, k)
    w_ref2, qweight2, scales2, g_idx2, sort_indices2 = quantize_experts(w2, n)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    triton_output = fused_moe(
        a,
        w_ref1.transpose(1, 2).contiguous(),
        w_ref2.transpose(1, 2).contiguous(),
        score,
        topk,
        renormalize=False,
    )
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        score,
        g_idx1,
        g_idx2,
        sort_indices1,
        sort_indices2,
        topk_weights,
        topk_ids,
        w1_scale=scales1,
        w2_scale=scales2,
    )

    assert compute_max_diff(marlin_output, triton_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
):
    """Debug-only check of ``single_marlin_moe`` against ``torch_moe_single``.

    Expert weights are quantized one expert at a time with
    ``marlin_quantize``; the kernel output on the quantized weights must be
    within a relative error of 1e-2 (per ``compute_max_diff``) of the
    PyTorch reference evaluated on the returned reference weights.
    """
    if topk > e:
        return

    # Filter act_order: these group sizes are not exercised with act_order.
    if act_order and group_size in (-1, k):
        return

    quant_type = scalar_types.uint4b8
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    # One (w_ref, qweight, scales, g_idx, sort_indices) tuple per expert.
    per_expert = []
    for expert in range(w.shape[0]):
        perm = torch.randperm(k)
        ref_w, q_w, s, g, sort_idx, _ = marlin_quantize(
            w[expert].transpose(1, 0), quant_type, group_size, act_order,
            perm)
        per_expert.append((ref_w, q_w, s, g, sort_idx))

    w_ref = stack_and_dev([item[0] for item in per_expert])
    qweight = stack_and_dev([item[1] for item in per_expert]).contiguous()
    scales = stack_and_dev([item[2] for item in per_expert])
    g_idx = stack_and_dev([item[3] for item in per_expert])
    sort_indices = stack_and_dev([item[4] for item in per_expert])

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    marlin_output = single_marlin_moe(a,
                                      qweight,
                                      scales,
                                      score,
                                      g_idx,
                                      sort_indices,
                                      topk,
                                      renormalize=False)
    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    relative_error = compute_max_diff(marlin_output, torch_output)
    assert relative_error < 1e-2
tests/kernels/utils.py
View file @
4851c202
...
...
@@ -3,7 +3,8 @@
import
itertools
import
random
from
numbers
import
Number
from
typing
import
Any
,
List
,
NamedTuple
,
Optional
,
Tuple
,
Union
from
typing
import
(
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Sequence
,
Tuple
,
Union
)
import
pytest
import
torch
...
...
@@ -13,6 +14,21 @@ from vllm.attention.backends.xformers import XFormersBackend
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_XFORMERS_ATTN_VAL
,
make_tensor_with_pad
)
# "test_aot_dispatch_dynamic" is kept out of the default set for now because
# there are some bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# Every opcheck test utility: the default set plus the AOT-dispatch check.
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    *DEFAULT_OPCHECK_TEST_UTILS,
    "test_aot_dispatch_dynamic",
)
class
QKVInputs
(
NamedTuple
):
'''
...
...
@@ -926,3 +942,19 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
ideal_output
=
test_params
.
packed_qkvo
.
ideal_output
torch
.
testing
.
assert_close
(
ideal_output
,
output_under_test
.
view_as
(
ideal_output
))
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    """Conditionally run ``torch.library.opcheck`` on a custom op.

    Args:
        op: The operator to check.
        args: Positional arguments forwarded to ``torch.library.opcheck``.
        kwargs: Optional keyword arguments forwarded alongside ``args``.
        test_utils: Name(s) of the opcheck tests to run; defaults to
            ``ALL_OPCHECK_TEST_UTILS``.
        raise_exception: Forwarded unchanged to ``torch.library.opcheck``.
        cond: When false, the check is skipped entirely.

    Returns:
        Whatever ``torch.library.opcheck`` returns, or an empty dict when
        ``cond`` is false.
    """
    if not cond:
        return {}
    return torch.library.opcheck(op,
                                 args,
                                 kwargs,
                                 test_utils=test_utils,
                                 raise_exception=raise_exception)
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment