Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2216a4e5
"tests/vscode:/vscode.git/clone" did not exist on "b6087a6beead9165f4c77ceba592b3651bb37de9"
Commit
2216a4e5
authored
Oct 23, 2024
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/main'
parents
ad385667
51c24c97
Changes
239
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
303 additions
and
185 deletions
+303
-185
tests/encoder_decoder/test_e2e_correctness.py
tests/encoder_decoder/test_e2e_correctness.py
+3
-3
tests/entrypoints/llm/test_chat.py
tests/entrypoints/llm/test_chat.py
+92
-0
tests/entrypoints/llm/test_encode.py
tests/entrypoints/llm/test_encode.py
+2
-3
tests/entrypoints/llm/test_generate.py
tests/entrypoints/llm/test_generate.py
+2
-91
tests/entrypoints/llm/test_generate_multiple_loras.py
tests/entrypoints/llm/test_generate_multiple_loras.py
+2
-3
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+2
-3
tests/entrypoints/llm/test_init.py
tests/entrypoints/llm/test_init.py
+22
-0
tests/entrypoints/llm/test_lazy_outlines.py
tests/entrypoints/llm/test_lazy_outlines.py
+8
-1
tests/entrypoints/offline_mode/test_offline_mode.py
tests/entrypoints/offline_mode/test_offline_mode.py
+33
-28
tests/entrypoints/openai/test_chat.py
tests/entrypoints/openai/test_chat.py
+18
-7
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+34
-0
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+1
-1
tests/entrypoints/openai/test_shutdown.py
tests/entrypoints/openai/test_shutdown.py
+1
-1
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+2
-0
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+28
-1
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+17
-21
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+16
-13
tests/kernels/test_machete_gemm.py
tests/kernels/test_machete_gemm.py
+5
-4
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+13
-3
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+2
-2
No files found.
tests/encoder_decoder/test_e2e_correctness.py
View file @
2216a4e5
...
...
@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
is_cpu
from
..conftest
import
DecoderPromptType
from
..models.utils
import
check_logprobs_close
...
...
@@ -35,7 +35,7 @@ def vllm_to_hf_output(
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
is_cpu
(),
current_platform
.
is_cpu
(),
reason
=
"CPU backend is not currently supported with encoder/decoder models"
)
def
test_encoder_decoder_e2e
(
...
...
@@ -50,7 +50,7 @@ def test_encoder_decoder_e2e(
enforce_eager
:
bool
,
)
->
None
:
'''
End-to-End (E2E) test for the encoder-decoder framework.
End-to-End (E2E) test for the encoder-decoder framework.
This test evaluates the encoder-decoder functionality using the BART
model. We compare the outputs of the Hugging Face and vLLM
implementations to ensure that both implementations produce consistent
...
...
tests/entrypoints/llm/test_chat.py
0 → 100644
View file @
2216a4e5
from
typing
import
List
import
pytest
from
vllm
import
LLM
from
..openai.test_vision
import
TEST_IMAGE_URLS
def
test_chat
():
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
messages
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt1
},
]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
def
test_multi_chat
():
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
prompt2
=
"Explain what among us is."
conversation1
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt1
},
]
conversation2
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt2
},
]
messages
=
[
conversation1
,
conversation2
]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
2
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
test_chat_multi_image
(
image_urls
:
List
[
str
]):
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
dtype
=
"bfloat16"
,
max_model_len
=
4096
,
max_num_seqs
=
5
,
enforce_eager
=
True
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
},
)
messages
=
[{
"role"
:
"user"
,
"content"
:
[
*
({
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
}
for
image_url
in
image_urls
),
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
],
}]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
>=
0
tests/entrypoints/llm/test_encode.py
View file @
2216a4e5
...
...
@@ -4,8 +4,7 @@ from typing import List
import
pytest
from
vllm
import
LLM
,
EmbeddingRequestOutput
,
PoolingParams
from
...conftest
import
cleanup
from
vllm.distributed
import
cleanup_dist_env_and_memory
MODEL_NAME
=
"intfloat/e5-mistral-7b-instruct"
...
...
@@ -41,7 +40,7 @@ def llm():
del
llm
cleanup
()
cleanup
_dist_env_and_memory
()
def
assert_outputs_equal
(
o1
:
List
[
EmbeddingRequestOutput
],
...
...
tests/entrypoints/llm/test_generate.py
View file @
2216a4e5
...
...
@@ -4,9 +4,7 @@ from typing import List
import
pytest
from
vllm
import
LLM
,
RequestOutput
,
SamplingParams
from
...conftest
import
cleanup
from
..openai.test_vision
import
TEST_IMAGE_URLS
from
vllm.distributed
import
cleanup_dist_env_and_memory
MODEL_NAME
=
"facebook/opt-125m"
...
...
@@ -40,7 +38,7 @@ def llm():
del
llm
cleanup
()
cleanup
_dist_env_and_memory
()
def
assert_outputs_equal
(
o1
:
List
[
RequestOutput
],
o2
:
List
[
RequestOutput
]):
...
...
@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied
outputs
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
None
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
def
test_chat
():
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
messages
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt1
},
]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
def
test_multi_chat
():
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
prompt2
=
"Explain what among us is."
conversation1
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt1
},
]
conversation2
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt2
},
]
messages
=
[
conversation1
,
conversation2
]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
2
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
test_chat_multi_image
(
image_urls
:
List
[
str
]):
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
dtype
=
"bfloat16"
,
max_model_len
=
4096
,
max_num_seqs
=
5
,
enforce_eager
=
True
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
},
)
messages
=
[{
"role"
:
"user"
,
"content"
:
[
*
({
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
}
for
image_url
in
image_urls
),
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
],
}]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
>=
0
tests/entrypoints/llm/test_generate_multiple_loras.py
View file @
2216a4e5
...
...
@@ -5,10 +5,9 @@ import pytest
from
huggingface_hub
import
snapshot_download
from
vllm
import
LLM
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.lora.request
import
LoRARequest
from
...conftest
import
cleanup
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
PROMPTS
=
[
...
...
@@ -39,7 +38,7 @@ def llm():
del
llm
cleanup
()
cleanup
_dist_env_and_memory
()
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
tests/entrypoints/llm/test_guided_generate.py
View file @
2216a4e5
...
...
@@ -5,12 +5,11 @@ import weakref
import
jsonschema
import
pytest
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.entrypoints.llm
import
LLM
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
...conftest
import
cleanup
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
...
...
@@ -23,7 +22,7 @@ def llm():
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup
()
cleanup
_dist_env_and_memory
()
@
pytest
.
mark
.
skip_global_cleanup
...
...
tests/entrypoints/llm/test_init.py
0 → 100644
View file @
2216a4e5
import
pytest
from
vllm
import
LLM
from
...utils
import
error_on_warning
MODEL_NAME
=
"facebook/opt-125m"
def
test_pos_args_deprecated
():
with
error_on_warning
(
DeprecationWarning
):
LLM
(
model
=
MODEL_NAME
,
tokenizer
=
MODEL_NAME
)
with
error_on_warning
(
DeprecationWarning
):
LLM
(
MODEL_NAME
,
tokenizer
=
MODEL_NAME
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'tokenizer'"
):
LLM
(
MODEL_NAME
,
MODEL_NAME
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'tokenizer', 'tokenizer_mode'"
):
LLM
(
MODEL_NAME
,
MODEL_NAME
,
"auto"
)
tests/entrypoints/llm/test_lazy_outlines.py
View file @
2216a4e5
import
sys
from
vllm
import
LLM
,
SamplingParams
from
vllm.distributed
import
cleanup_dist_env_and_memory
def
test_lazy_outlines
(
sample_regex
):
...
...
@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
]
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
# Create an LLM without guided decoding as a baseline.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.3
)
...
...
@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex):
# make sure outlines is not imported
assert
'outlines'
not
in
sys
.
modules
# Destroy the LLM object and free up the GPU memory.
del
llm
cleanup_dist_env_and_memory
()
# Create an LLM with guided decoding enabled.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
guided_decoding_backend
=
"lm-format-enforcer"
,
gpu_memory_utilization
=
0.
3
)
gpu_memory_utilization
=
0.
6
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
outputs
=
llm
.
generate
(
prompts
=
[
...
...
tests/entrypoints/offline_mode/test_offline_mode.py
View file @
2216a4e5
"""Tests for HF_HUB_OFFLINE mode"""
import
importlib
import
sys
import
weakref
import
pytest
from
vllm
import
LLM
from
...conftest
import
cleanup
MODEL_NAME
=
"facebook/opt-125m"
from
vllm.distributed
import
cleanup_dist_env_and_memory
MODEL_CONFIGS
=
[
{
"model"
:
"facebook/opt-125m"
,
"enforce_eager"
:
True
,
"gpu_memory_utilization"
:
0.20
,
"max_model_len"
:
64
,
"max_num_batched_tokens"
:
64
,
"max_num_seqs"
:
64
,
"tensor_parallel_size"
:
1
,
},
{
"model"
:
"mistralai/Mistral-7B-Instruct-v0.1"
,
"enforce_eager"
:
True
,
"gpu_memory_utilization"
:
0.95
,
"max_model_len"
:
64
,
"max_num_batched_tokens"
:
64
,
"max_num_seqs"
:
64
,
"tensor_parallel_size"
:
1
,
"tokenizer_mode"
:
"mistral"
,
},
]
@
pytest
.
fixture
(
scope
=
"module"
)
def
llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
MODEL_NAME
,
max_num_batched_tokens
=
4096
,
tensor_parallel_size
=
1
,
gpu_memory_utilization
=
0.10
,
enforce_eager
=
True
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
def
cache_models
():
# Cache model files first
for
model_config
in
MODEL_CONFIGS
:
LLM
(
**
model_config
)
cleanup_dist_env_and_memory
()
del
llm
cleanup
()
yield
@
pytest
.
mark
.
skip_global_cleanup
def
test_offline_mode
(
llm
:
LLM
,
monkeypatch
):
# we use the llm fixture to ensure the model files are in-cache
del
llm
@
pytest
.
mark
.
usefixtures
(
"cache_models"
)
def
test_offline_mode
(
monkeypatch
):
# Set HF to offline mode and ensure we can still construct an LLM
try
:
monkeypatch
.
setenv
(
"HF_HUB_OFFLINE"
,
"1"
)
# Need to re-import huggingface_hub and friends to setup offline mode
_re_import_modules
()
# Cached model files should be used in offline mode
LLM
(
model
=
MODEL_NAME
,
max_num_batched_tokens
=
4096
,
tensor_parallel_size
=
1
,
gpu_memory_utilization
=
0.10
,
enforce_eager
=
True
)
for
model_config
in
MODEL_CONFIGS
:
LLM
(
**
model_config
)
finally
:
# Reset the environment after the test
# NB: Assuming tests are run in online mode
...
...
tests/entrypoints/openai/test_chat.py
View file @
2216a4e5
...
...
@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME
=
"typeof/zephyr-7b-beta-lora"
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@
pytest
.
mark
.
asyncio
async
def
test_response_format_json_schema
(
client
:
openai
.
AsyncOpenAI
):
prompt
=
'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema
for
_
in
range
(
2
):
resp
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
(
'what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}'
)
"role"
:
"user"
,
"content"
:
prompt
}],
)
content
=
resp
.
choices
[
0
].
message
.
content
assert
content
is
not
None
with
pytest
.
raises
((
json
.
JSONDecodeError
,
AssertionError
)):
loaded
=
json
.
loads
(
content
)
assert
loaded
==
{
"result"
:
2
},
loaded
for
_
in
range
(
2
):
resp
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
prompt
}],
response_format
=
{
"type"
:
"json_schema"
,
...
...
tests/entrypoints/openai/test_completion.py
View file @
2216a4e5
...
...
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert
""
.
join
(
chunks
)
==
single_output
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
,
"zephyr-lora"
,
"zephyr-pa"
],
)
async
def
test_parallel_streaming
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""
prompt
=
"What is an LLM?"
n
=
3
max_tokens
=
5
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
max_tokens
=
max_tokens
,
n
=
n
,
stream
=
True
)
chunks
:
List
[
List
[
str
]]
=
[[]
for
i
in
range
(
n
)]
finish_reason_count
=
0
async
for
chunk
in
stream
:
index
=
chunk
.
choices
[
0
].
index
text
=
chunk
.
choices
[
0
].
text
chunks
[
index
].
append
(
text
)
if
chunk
.
choices
[
0
].
finish_reason
is
not
None
:
finish_reason_count
+=
1
assert
finish_reason_count
==
n
for
chunk
in
chunks
:
assert
len
(
chunk
)
==
max_tokens
print
(
""
.
join
(
chunk
))
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
...
...
tests/entrypoints/openai/test_serving_chat.py
View file @
2216a4e5
...
...
@@ -22,12 +22,12 @@ class MockHFConfig:
@
dataclass
class
MockModelConfig
:
task
=
"generate"
tokenizer
=
MODEL_NAME
trust_remote_code
=
False
tokenizer_mode
=
"auto"
max_model_len
=
100
tokenizer_revision
=
None
embedding_mode
=
False
multimodal_config
=
MultiModalConfig
()
hf_config
=
MockHFConfig
()
...
...
tests/entrypoints/openai/test_shutdown.py
View file @
2216a4e5
...
...
@@ -6,7 +6,7 @@ import pytest
from
...utils
import
RemoteOpenAIServer
MODEL_NAME
=
"
HuggingFaceH4/zephyr-7b-beta
"
MODEL_NAME
=
"
meta-llama/Llama-3.2-1B
"
@
pytest
.
mark
.
asyncio
...
...
tests/entrypoints/openai/test_vision.py
View file @
2216a4e5
...
...
@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
args
=
[
"--task"
,
"generate"
,
"--dtype"
,
"bfloat16"
,
"--max-model-len"
,
...
...
tests/entrypoints/test_chat_utils.py
View file @
2216a4e5
...
...
@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
@
pytest
.
fixture
(
scope
=
"module"
)
def
phi3v_model_config
():
return
ModelConfig
(
PHI3V_MODEL_ID
,
PHI3V_MODEL_ID
,
task
=
"generate"
,
tokenizer
=
PHI3V_MODEL_ID
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
...
...
@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text"
:
"What about these two?"
}]
}],
phi3v_model_config
,
phi3v_tokenizer
)
def
test_parse_chat_messages_multiple_images_uncommon_input
(
phi3v_model_config
,
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
([{
"role"
:
"user"
,
"content"
:
[
"What's in these images?"
,
{
"image_url"
:
image_url
},
{
"image_url"
:
image_url
}
]
}],
phi3v_model_config
,
phi3v_tokenizer
)
assert
conversation
==
[{
"role"
:
"user"
,
"content"
:
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
tests/kernels/test_attention_selector.py
View file @
2216a4e5
...
...
@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable
(
monkeypatch
,
name
)
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.is_cpu"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
with
patch
(
"vllm.attention.selector.current_platform.is_cpu"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"TORCH_SDPA"
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.is_hip"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.is_openvino"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"OPENVINO"
else
:
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
name
...
...
@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
(
7
,
5
)):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported data type
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
"fp8"
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
"fp8"
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported block size
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend
=
which_attn_to_use
(
16
,
1
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# flash-attn is not installed
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported head size
backend
=
which_attn_to_use
(
17
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
17
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
...
...
@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
with
pytest
.
raises
(
ValueError
):
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
\ No newline at end of file
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
tests/kernels/test_flash_attn.py
View file @
2216a4e5
...
...
@@ -78,6 +78,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
int
],
...
...
@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
block_size
:
int
,
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
sliding_window
:
Optional
[
int
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
seed_everything
(
0
)
...
...
@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
num_blocks
,
...
...
@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv(
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
).
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
,
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
...
@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
...
...
@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
assert
num_query_heads
%
num_kv_heads
==
0
max_query_len
=
max
(
query_lens
)
max_kv_len
=
max
(
kv_lens
)
window_size
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
scale
=
head_size
**-
0.5
...
...
tests/kernels/test_machete_gemm.py
View file @
2216a4e5
...
...
@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q
=
w_q
.
t
().
contiguous
().
t
()
# convert to col major
w_q_machete
=
ops
.
machete_prepack_B
(
w_q
,
wtype
)
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
wtype
))
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
wtype
.
id
))
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
...
...
@@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule
=
schedule
,
)
opcheck
(
torch
.
ops
.
_C
.
machete_gemm
,
(
a
,
w_q_machete
,
wtype
,
w_s
,
maybe_convert_zeropoints
(
w_zp
,
w_s
),
group_size
,
None
,
None
,
None
,
schedule
))
opcheck
(
torch
.
ops
.
_C
.
machete_gemm
,
(
a
,
w_q_machete
,
wtype
.
id
,
w_s
,
maybe_convert_zeropoints
(
w_zp
,
w_s
),
group_size
,
None
,
None
,
None
,
schedule
))
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
...
...
tests/kernels/test_marlin_gemm.py
View file @
2216a4e5
...
...
@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
workspace
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_fp32_reduce
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
...
...
@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
assert
max_diff
<
0.04
# TODO: find better way to test this?
@
torch
.
compile
(
fullgraph
=
True
)
def
marlin_24_gemm_tester
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
scratch
,
quant_type
,
size_m
,
size_n
,
size_k
):
return
ops
.
gptq_marlin_24_gemm
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
scratch
,
quant_type
,
size_m
,
size_n
,
size_k
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
...
...
@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
,
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
workspace_24
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
workspace_24
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_
marlin_24_gemm
(
output
=
marlin_24_gemm
_tester
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
...
...
tests/kernels/test_moe.py
View file @
2216a4e5
...
...
@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
requires_grad
=
False
)
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales1
,
zp
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
scales1
,
zp
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
.
id
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
...
...
Prev
1
2
3
4
5
6
7
8
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment