Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2216a4e5
Commit
2216a4e5
authored
Oct 23, 2024
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/main'
parents
ad385667
51c24c97
Changes
239
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
303 additions
and
185 deletions
+303
-185
tests/encoder_decoder/test_e2e_correctness.py
tests/encoder_decoder/test_e2e_correctness.py
+3
-3
tests/entrypoints/llm/test_chat.py
tests/entrypoints/llm/test_chat.py
+92
-0
tests/entrypoints/llm/test_encode.py
tests/entrypoints/llm/test_encode.py
+2
-3
tests/entrypoints/llm/test_generate.py
tests/entrypoints/llm/test_generate.py
+2
-91
tests/entrypoints/llm/test_generate_multiple_loras.py
tests/entrypoints/llm/test_generate_multiple_loras.py
+2
-3
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+2
-3
tests/entrypoints/llm/test_init.py
tests/entrypoints/llm/test_init.py
+22
-0
tests/entrypoints/llm/test_lazy_outlines.py
tests/entrypoints/llm/test_lazy_outlines.py
+8
-1
tests/entrypoints/offline_mode/test_offline_mode.py
tests/entrypoints/offline_mode/test_offline_mode.py
+33
-28
tests/entrypoints/openai/test_chat.py
tests/entrypoints/openai/test_chat.py
+18
-7
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+34
-0
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+1
-1
tests/entrypoints/openai/test_shutdown.py
tests/entrypoints/openai/test_shutdown.py
+1
-1
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+2
-0
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+28
-1
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+17
-21
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+16
-13
tests/kernels/test_machete_gemm.py
tests/kernels/test_machete_gemm.py
+5
-4
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+13
-3
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+2
-2
No files found.
tests/encoder_decoder/test_e2e_correctness.py
View file @
2216a4e5
...
@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
...
@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
import
pytest
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
from
transformers
import
AutoModelForSeq2SeqLM
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
is_cpu
from
..conftest
import
DecoderPromptType
from
..conftest
import
DecoderPromptType
from
..models.utils
import
check_logprobs_close
from
..models.utils
import
check_logprobs_close
...
@@ -35,7 +35,7 @@ def vllm_to_hf_output(
...
@@ -35,7 +35,7 @@ def vllm_to_hf_output(
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
is_cpu
(),
current_platform
.
is_cpu
(),
reason
=
"CPU backend is not currently supported with encoder/decoder models"
reason
=
"CPU backend is not currently supported with encoder/decoder models"
)
)
def
test_encoder_decoder_e2e
(
def
test_encoder_decoder_e2e
(
...
@@ -50,7 +50,7 @@ def test_encoder_decoder_e2e(
...
@@ -50,7 +50,7 @@ def test_encoder_decoder_e2e(
enforce_eager
:
bool
,
enforce_eager
:
bool
,
)
->
None
:
)
->
None
:
'''
'''
End-to-End (E2E) test for the encoder-decoder framework.
End-to-End (E2E) test for the encoder-decoder framework.
This test evaluates the encoder-decoder functionality using the BART
This test evaluates the encoder-decoder functionality using the BART
model. We compare the outputs of the Hugging Face and vLLM
model. We compare the outputs of the Hugging Face and vLLM
implementations to ensure that both implementations produce consistent
implementations to ensure that both implementations produce consistent
...
...
tests/entrypoints/llm/test_chat.py
0 → 100644
View file @
2216a4e5
from
typing
import
List
import
pytest
from
vllm
import
LLM
from
..openai.test_vision
import
TEST_IMAGE_URLS
def
test_chat
():
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
messages
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt1
},
]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
def
test_multi_chat
():
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
prompt2
=
"Explain what among us is."
conversation1
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt1
},
]
conversation2
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt2
},
]
messages
=
[
conversation1
,
conversation2
]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
2
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
test_chat_multi_image
(
image_urls
:
List
[
str
]):
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
dtype
=
"bfloat16"
,
max_model_len
=
4096
,
max_num_seqs
=
5
,
enforce_eager
=
True
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
},
)
messages
=
[{
"role"
:
"user"
,
"content"
:
[
*
({
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
}
for
image_url
in
image_urls
),
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
],
}]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
>=
0
tests/entrypoints/llm/test_encode.py
View file @
2216a4e5
...
@@ -4,8 +4,7 @@ from typing import List
...
@@ -4,8 +4,7 @@ from typing import List
import
pytest
import
pytest
from
vllm
import
LLM
,
EmbeddingRequestOutput
,
PoolingParams
from
vllm
import
LLM
,
EmbeddingRequestOutput
,
PoolingParams
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
...conftest
import
cleanup
MODEL_NAME
=
"intfloat/e5-mistral-7b-instruct"
MODEL_NAME
=
"intfloat/e5-mistral-7b-instruct"
...
@@ -41,7 +40,7 @@ def llm():
...
@@ -41,7 +40,7 @@ def llm():
del
llm
del
llm
cleanup
()
cleanup
_dist_env_and_memory
()
def
assert_outputs_equal
(
o1
:
List
[
EmbeddingRequestOutput
],
def
assert_outputs_equal
(
o1
:
List
[
EmbeddingRequestOutput
],
...
...
tests/entrypoints/llm/test_generate.py
View file @
2216a4e5
...
@@ -4,9 +4,7 @@ from typing import List
...
@@ -4,9 +4,7 @@ from typing import List
import
pytest
import
pytest
from
vllm
import
LLM
,
RequestOutput
,
SamplingParams
from
vllm
import
LLM
,
RequestOutput
,
SamplingParams
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
...conftest
import
cleanup
from
..openai.test_vision
import
TEST_IMAGE_URLS
MODEL_NAME
=
"facebook/opt-125m"
MODEL_NAME
=
"facebook/opt-125m"
...
@@ -40,7 +38,7 @@ def llm():
...
@@ -40,7 +38,7 @@ def llm():
del
llm
del
llm
cleanup
()
cleanup
_dist_env_and_memory
()
def
assert_outputs_equal
(
o1
:
List
[
RequestOutput
],
o2
:
List
[
RequestOutput
]):
def
assert_outputs_equal
(
o1
:
List
[
RequestOutput
],
o2
:
List
[
RequestOutput
]):
...
@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM):
...
@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied
# sampling_params is None, default params should be applied
outputs
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
None
)
outputs
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
None
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
def
test_chat
():
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
messages
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt1
},
]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
def
test_multi_chat
():
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
prompt2
=
"Explain what among us is."
conversation1
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt1
},
]
conversation2
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
prompt2
},
]
messages
=
[
conversation1
,
conversation2
]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
2
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
test_chat_multi_image
(
image_urls
:
List
[
str
]):
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
dtype
=
"bfloat16"
,
max_model_len
=
4096
,
max_num_seqs
=
5
,
enforce_eager
=
True
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
},
)
messages
=
[{
"role"
:
"user"
,
"content"
:
[
*
({
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
}
for
image_url
in
image_urls
),
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
],
}]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
>=
0
tests/entrypoints/llm/test_generate_multiple_loras.py
View file @
2216a4e5
...
@@ -5,10 +5,9 @@ import pytest
...
@@ -5,10 +5,9 @@ import pytest
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
...conftest
import
cleanup
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
PROMPTS
=
[
PROMPTS
=
[
...
@@ -39,7 +38,7 @@ def llm():
...
@@ -39,7 +38,7 @@ def llm():
del
llm
del
llm
cleanup
()
cleanup
_dist_env_and_memory
()
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
tests/entrypoints/llm/test_guided_generate.py
View file @
2216a4e5
...
@@ -5,12 +5,11 @@ import weakref
...
@@ -5,12 +5,11 @@ import weakref
import
jsonschema
import
jsonschema
import
pytest
import
pytest
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.llm
import
LLM
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
...conftest
import
cleanup
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
...
@@ -23,7 +22,7 @@ def llm():
...
@@ -23,7 +22,7 @@ def llm():
with
llm
.
deprecate_legacy_api
():
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
yield
weakref
.
proxy
(
llm
)
del
llm
del
llm
cleanup
()
cleanup
_dist_env_and_memory
()
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
...
...
tests/entrypoints/llm/test_init.py
0 → 100644
View file @
2216a4e5
import
pytest
from
vllm
import
LLM
from
...utils
import
error_on_warning
MODEL_NAME
=
"facebook/opt-125m"
def
test_pos_args_deprecated
():
with
error_on_warning
(
DeprecationWarning
):
LLM
(
model
=
MODEL_NAME
,
tokenizer
=
MODEL_NAME
)
with
error_on_warning
(
DeprecationWarning
):
LLM
(
MODEL_NAME
,
tokenizer
=
MODEL_NAME
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'tokenizer'"
):
LLM
(
MODEL_NAME
,
MODEL_NAME
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'tokenizer', 'tokenizer_mode'"
):
LLM
(
MODEL_NAME
,
MODEL_NAME
,
"auto"
)
tests/entrypoints/llm/test_lazy_outlines.py
View file @
2216a4e5
import
sys
import
sys
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.distributed
import
cleanup_dist_env_and_memory
def
test_lazy_outlines
(
sample_regex
):
def
test_lazy_outlines
(
sample_regex
):
...
@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
...
@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
]
]
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
# Create an LLM without guided decoding as a baseline.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.3
)
gpu_memory_utilization
=
0.3
)
...
@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex):
...
@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex):
# make sure outlines is not imported
# make sure outlines is not imported
assert
'outlines'
not
in
sys
.
modules
assert
'outlines'
not
in
sys
.
modules
# Destroy the LLM object and free up the GPU memory.
del
llm
cleanup_dist_env_and_memory
()
# Create an LLM with guided decoding enabled.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
enforce_eager
=
True
,
guided_decoding_backend
=
"lm-format-enforcer"
,
guided_decoding_backend
=
"lm-format-enforcer"
,
gpu_memory_utilization
=
0.
3
)
gpu_memory_utilization
=
0.
6
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
[
prompts
=
[
...
...
tests/entrypoints/offline_mode/test_offline_mode.py
View file @
2216a4e5
"""Tests for HF_HUB_OFFLINE mode"""
"""Tests for HF_HUB_OFFLINE mode"""
import
importlib
import
importlib
import
sys
import
sys
import
weakref
import
pytest
import
pytest
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
...conftest
import
cleanup
MODEL_CONFIGS
=
[
MODEL_NAME
=
"facebook/opt-125m"
{
"model"
:
"facebook/opt-125m"
,
"enforce_eager"
:
True
,
"gpu_memory_utilization"
:
0.20
,
"max_model_len"
:
64
,
"max_num_batched_tokens"
:
64
,
"max_num_seqs"
:
64
,
"tensor_parallel_size"
:
1
,
},
{
"model"
:
"mistralai/Mistral-7B-Instruct-v0.1"
,
"enforce_eager"
:
True
,
"gpu_memory_utilization"
:
0.95
,
"max_model_len"
:
64
,
"max_num_batched_tokens"
:
64
,
"max_num_seqs"
:
64
,
"tensor_parallel_size"
:
1
,
"tokenizer_mode"
:
"mistral"
,
},
]
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
llm
():
def
cache_models
():
# pytest caches the fixture so we use weakref.proxy to
# Cache model files first
# enable garbage collection
for
model_config
in
MODEL_CONFIGS
:
llm
=
LLM
(
model
=
MODEL_NAME
,
LLM
(
**
model_config
)
max_num_batched_tokens
=
4096
,
cleanup_dist_env_and_memory
()
tensor_parallel_size
=
1
,
gpu_memory_utilization
=
0.10
,
enforce_eager
=
True
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
yield
cleanup
()
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
def
test_offline_mode
(
llm
:
LLM
,
monkeypatch
):
@
pytest
.
mark
.
usefixtures
(
"cache_models"
)
# we use the llm fixture to ensure the model files are in-cache
def
test_offline_mode
(
monkeypatch
):
del
llm
# Set HF to offline mode and ensure we can still construct an LLM
# Set HF to offline mode and ensure we can still construct an LLM
try
:
try
:
monkeypatch
.
setenv
(
"HF_HUB_OFFLINE"
,
"1"
)
monkeypatch
.
setenv
(
"HF_HUB_OFFLINE"
,
"1"
)
# Need to re-import huggingface_hub and friends to setup offline mode
# Need to re-import huggingface_hub and friends to setup offline mode
_re_import_modules
()
_re_import_modules
()
# Cached model files should be used in offline mode
# Cached model files should be used in offline mode
LLM
(
model
=
MODEL_NAME
,
for
model_config
in
MODEL_CONFIGS
:
max_num_batched_tokens
=
4096
,
LLM
(
**
model_config
)
tensor_parallel_size
=
1
,
gpu_memory_utilization
=
0.10
,
enforce_eager
=
True
)
finally
:
finally
:
# Reset the environment after the test
# Reset the environment after the test
# NB: Assuming tests are run in online mode
# NB: Assuming tests are run in online mode
...
...
tests/entrypoints/openai/test_chat.py
View file @
2216a4e5
...
@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401
...
@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here
# any model with a chat template should work here
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME
=
"typeof/zephyr-7b-beta-lora"
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
...
@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
...
@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_response_format_json_schema
(
client
:
openai
.
AsyncOpenAI
):
async
def
test_response_format_json_schema
(
client
:
openai
.
AsyncOpenAI
):
prompt
=
'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema
for
_
in
range
(
2
):
for
_
in
range
(
2
):
resp
=
await
client
.
chat
.
completions
.
create
(
resp
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
[{
messages
=
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
prompt
"content"
:
(
'what is 1+1? please respond with a JSON object, '
}],
'the format is {"result": 2}'
)
)
content
=
resp
.
choices
[
0
].
message
.
content
assert
content
is
not
None
with
pytest
.
raises
((
json
.
JSONDecodeError
,
AssertionError
)):
loaded
=
json
.
loads
(
content
)
assert
loaded
==
{
"result"
:
2
},
loaded
for
_
in
range
(
2
):
resp
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
prompt
}],
}],
response_format
=
{
response_format
=
{
"type"
:
"json_schema"
,
"type"
:
"json_schema"
,
...
...
tests/entrypoints/openai/test_completion.py
View file @
2216a4e5
...
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
...
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert
""
.
join
(
chunks
)
==
single_output
assert
""
.
join
(
chunks
)
==
single_output
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
,
"zephyr-lora"
,
"zephyr-pa"
],
)
async
def
test_parallel_streaming
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""
prompt
=
"What is an LLM?"
n
=
3
max_tokens
=
5
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
max_tokens
=
max_tokens
,
n
=
n
,
stream
=
True
)
chunks
:
List
[
List
[
str
]]
=
[[]
for
i
in
range
(
n
)]
finish_reason_count
=
0
async
for
chunk
in
stream
:
index
=
chunk
.
choices
[
0
].
index
text
=
chunk
.
choices
[
0
].
text
chunks
[
index
].
append
(
text
)
if
chunk
.
choices
[
0
].
finish_reason
is
not
None
:
finish_reason_count
+=
1
assert
finish_reason_count
==
n
for
chunk
in
chunks
:
assert
len
(
chunk
)
==
max_tokens
print
(
""
.
join
(
chunk
))
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_name"
,
"model_name"
,
...
...
tests/entrypoints/openai/test_serving_chat.py
View file @
2216a4e5
...
@@ -22,12 +22,12 @@ class MockHFConfig:
...
@@ -22,12 +22,12 @@ class MockHFConfig:
@
dataclass
@
dataclass
class
MockModelConfig
:
class
MockModelConfig
:
task
=
"generate"
tokenizer
=
MODEL_NAME
tokenizer
=
MODEL_NAME
trust_remote_code
=
False
trust_remote_code
=
False
tokenizer_mode
=
"auto"
tokenizer_mode
=
"auto"
max_model_len
=
100
max_model_len
=
100
tokenizer_revision
=
None
tokenizer_revision
=
None
embedding_mode
=
False
multimodal_config
=
MultiModalConfig
()
multimodal_config
=
MultiModalConfig
()
hf_config
=
MockHFConfig
()
hf_config
=
MockHFConfig
()
...
...
tests/entrypoints/openai/test_shutdown.py
View file @
2216a4e5
...
@@ -6,7 +6,7 @@ import pytest
...
@@ -6,7 +6,7 @@ import pytest
from
...utils
import
RemoteOpenAIServer
from
...utils
import
RemoteOpenAIServer
MODEL_NAME
=
"
HuggingFaceH4/zephyr-7b-beta
"
MODEL_NAME
=
"
meta-llama/Llama-3.2-1B
"
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
...
tests/entrypoints/openai/test_vision.py
View file @
2216a4e5
...
@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
...
@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
def
server
():
args
=
[
args
=
[
"--task"
,
"generate"
,
"--dtype"
,
"--dtype"
,
"bfloat16"
,
"bfloat16"
,
"--max-model-len"
,
"--max-model-len"
,
...
...
tests/entrypoints/test_chat_utils.py
View file @
2216a4e5
...
@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
...
@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
phi3v_model_config
():
def
phi3v_model_config
():
return
ModelConfig
(
PHI3V_MODEL_ID
,
return
ModelConfig
(
PHI3V_MODEL_ID
,
PHI3V_MODEL_ID
,
task
=
"generate"
,
tokenizer
=
PHI3V_MODEL_ID
,
tokenizer_mode
=
"auto"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
...
@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
...
@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text"
:
"What about these two?"
"text"
:
"What about these two?"
}]
}]
}],
phi3v_model_config
,
phi3v_tokenizer
)
}],
phi3v_model_config
,
phi3v_tokenizer
)
def
test_parse_chat_messages_multiple_images_uncommon_input
(
phi3v_model_config
,
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
([{
"role"
:
"user"
,
"content"
:
[
"What's in these images?"
,
{
"image_url"
:
image_url
},
{
"image_url"
:
image_url
}
]
}],
phi3v_model_config
,
phi3v_tokenizer
)
assert
conversation
==
[{
"role"
:
"user"
,
"content"
:
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
tests/kernels/test_attention_selector.py
View file @
2216a4e5
...
@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch):
...
@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable
(
monkeypatch
,
name
)
override_backend_env_variable
(
monkeypatch
,
name
)
if
device
==
"cpu"
:
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.is_cpu"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.current_platform.is_cpu"
,
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
return_value
=
True
):
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"TORCH_SDPA"
assert
backend
.
name
==
"TORCH_SDPA"
elif
device
==
"hip"
:
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.is_hip"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.is_hip"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
16
,
False
)
False
)
assert
backend
.
name
==
"ROCM_FLASH"
assert
backend
.
name
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.is_openvino"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.is_openvino"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
16
,
False
)
False
)
assert
backend
.
name
==
"OPENVINO"
assert
backend
.
name
==
"OPENVINO"
else
:
else
:
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
False
)
assert
backend
.
name
==
name
assert
backend
.
name
==
name
...
@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch):
...
@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
# Unsupported CUDA arch
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
(
7
,
5
)):
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
(
7
,
5
)):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported data type
# Unsupported data type
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
# Unsupported kv cache data type
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
"fp8"
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
"fp8"
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported block size
# Unsupported block size
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
8
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend
=
which_attn_to_use
(
16
,
1
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# flash-attn is not installed
# flash-attn is not installed
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported head size
# Unsupported head size
backend
=
which_attn_to_use
(
17
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
17
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
# Attention-free models should bypass env and use PlaceholderAttention
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
True
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
...
@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch):
...
@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
\ No newline at end of file
tests/kernels/test_flash_attn.py
View file @
2216a4e5
...
@@ -78,6 +78,7 @@ def ref_paged_attn(
...
@@ -78,6 +78,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
int
],
kv_lens
:
List
[
int
],
...
@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
num_blocks
:
int
,
sliding_window
:
Optional
[
int
],
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
seed_everything
(
0
)
seed_everything
(
0
)
...
@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
...
@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
assert
num_query_heads
%
num_kv_heads
==
0
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
num_blocks
,
key_cache
=
torch
.
randn
(
num_blocks
,
...
@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv(
...
@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv(
block_table
=
block_tables
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
).
squeeze
(
1
)
).
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
query
=
query
,
query
=
query
,
key_cache
=
key_cache
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
,
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
...
@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
...
@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
assert
num_query_heads
%
num_kv_heads
==
0
assert
num_query_heads
%
num_kv_heads
==
0
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
max_kv_len
=
max
(
kv_lens
)
max_kv_len
=
max
(
kv_lens
)
window_size
=
((
sliding_window
,
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
(
-
1
,
-
1
))
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
...
...
tests/kernels/test_machete_gemm.py
View file @
2216a4e5
...
@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
...
@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q
=
w_q
.
t
().
contiguous
().
t
()
# convert to col major
w_q
=
w_q
.
t
().
contiguous
().
t
()
# convert to col major
w_q_machete
=
ops
.
machete_prepack_B
(
w_q
,
wtype
)
w_q_machete
=
ops
.
machete_prepack_B
(
w_q
,
wtype
)
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
wtype
))
opcheck
(
torch
.
ops
.
_C
.
machete_prepack_B
,
(
w_q
,
wtype
.
id
))
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
...
@@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
...
@@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule
=
schedule
,
schedule
=
schedule
,
)
)
opcheck
(
torch
.
ops
.
_C
.
machete_gemm
,
opcheck
(
(
a
,
w_q_machete
,
wtype
,
w_s
,
maybe_convert_zeropoints
(
torch
.
ops
.
_C
.
machete_gemm
,
w_zp
,
w_s
),
group_size
,
None
,
None
,
None
,
schedule
))
(
a
,
w_q_machete
,
wtype
.
id
,
w_s
,
maybe_convert_zeropoints
(
w_zp
,
w_s
),
group_size
,
None
,
None
,
None
,
schedule
))
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# Relax atol when we have zeropoints since the way machete applies
...
...
tests/kernels/test_marlin_gemm.py
View file @
2216a4e5
...
@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
...
@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
opcheck
(
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
workspace
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_fp32_reduce
),
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_fp32_reduce
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
...
@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
...
@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
assert
max_diff
<
0.04
assert
max_diff
<
0.04
# TODO: find better way to test this?
@
torch
.
compile
(
fullgraph
=
True
)
def
marlin_24_gemm_tester
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
scratch
,
quant_type
,
size_m
,
size_n
,
size_k
):
return
ops
.
gptq_marlin_24_gemm
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
scratch
,
quant_type
,
size_m
,
size_n
,
size_k
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_24_K_CHUNKS
)
...
@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
...
@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
,
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
,
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
(
a_input
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
,
workspace_24
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
workspace_24
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]),
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_
marlin_24_gemm
(
output
=
marlin_24_gemm
_tester
(
a_input
,
a_input
,
marlin_24_q_w_comp
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_meta
,
...
...
tests/kernels/test_moe.py
View file @
2216a4e5
...
@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
...
@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
requires_grad
=
False
)
requires_grad
=
False
)
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales1
,
zp
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
scales1
,
zp
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
.
id
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
...
...
Prev
1
2
3
4
5
6
7
8
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment