Commit 2216a4e5 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
...@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple ...@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
import pytest import pytest
from transformers import AutoModelForSeq2SeqLM from transformers import AutoModelForSeq2SeqLM
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from ..conftest import DecoderPromptType from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close from ..models.utils import check_logprobs_close
...@@ -35,7 +35,7 @@ def vllm_to_hf_output( ...@@ -35,7 +35,7 @@ def vllm_to_hf_output(
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.skipif( @pytest.mark.skipif(
is_cpu(), current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models" reason="CPU backend is not currently supported with encoder/decoder models"
) )
def test_encoder_decoder_e2e( def test_encoder_decoder_e2e(
...@@ -50,7 +50,7 @@ def test_encoder_decoder_e2e( ...@@ -50,7 +50,7 @@ def test_encoder_decoder_e2e(
enforce_eager: bool, enforce_eager: bool,
) -> None: ) -> None:
''' '''
End-to-End (E2E) test for the encoder-decoder framework. End-to-End (E2E) test for the encoder-decoder framework.
This test evaluates the encoder-decoder functionality using the BART This test evaluates the encoder-decoder functionality using the BART
model. We compare the outputs of the Hugging Face and vLLM model. We compare the outputs of the Hugging Face and vLLM
implementations to ensure that both implementations produce consistent implementations to ensure that both implementations produce consistent
......
from typing import List
import pytest
from vllm import LLM
from ..openai.test_vision import TEST_IMAGE_URLS
def test_chat():
    """Smoke-test ``LLM.chat`` with a single system + user conversation."""
    system_turn = {"role": "system", "content": "You are a helpful assistant"}
    user_turn = {"role": "user", "content": "Explain the concept of entropy."}

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
    outputs = llm.chat([system_turn, user_turn])

    # One conversation in, one result out.
    assert len(outputs) == 1
def test_multi_chat():
    """Check that ``LLM.chat`` accepts a batch of conversations.

    Two conversations that share the same system prompt but ask different
    user questions are submitted in one call; the test expects one output
    per conversation.
    """
    user_prompts = (
        "Explain the concept of entropy.",
        "Explain what among us is.",
    )
    conversations = [
        [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": user_prompt},
        ]
        for user_prompt in user_prompts
    ]

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
    outputs = llm.chat(conversations)
    assert len(outputs) == 2
@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
    """Send one user message carrying several images through ``LLM.chat``.

    The user turn is built from one ``image_url`` content part per URL in
    ``image_urls`` followed by a single text part. The engine is configured
    with ``limit_mm_per_prompt={"image": 2}``, matching the two URLs the
    parametrization supplies.
    """
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 2},
    )

    image_parts = [{
        "type": "image_url",
        "image_url": {
            "url": image_url
        }
    } for image_url in image_urls]
    text_part = {"type": "text", "text": "What's in this image?"}

    messages = [{
        "role": "user",
        "content": [*image_parts, text_part],
    }]

    outputs = llm.chat(messages)
    # The previous check, `len(outputs) >= 0`, could never fail. A single
    # conversation is expected to yield exactly one output, as asserted by
    # test_chat for the text-only case.
    assert len(outputs) == 1
...@@ -4,8 +4,7 @@ from typing import List ...@@ -4,8 +4,7 @@ from typing import List
import pytest import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
from ...conftest import cleanup
MODEL_NAME = "intfloat/e5-mistral-7b-instruct" MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
...@@ -41,7 +40,7 @@ def llm(): ...@@ -41,7 +40,7 @@ def llm():
del llm del llm
cleanup() cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput], def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
......
...@@ -4,9 +4,7 @@ from typing import List ...@@ -4,9 +4,7 @@ from typing import List
import pytest import pytest
from vllm import LLM, RequestOutput, SamplingParams from vllm import LLM, RequestOutput, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from ...conftest import cleanup
from ..openai.test_vision import TEST_IMAGE_URLS
MODEL_NAME = "facebook/opt-125m" MODEL_NAME = "facebook/opt-125m"
...@@ -40,7 +38,7 @@ def llm(): ...@@ -40,7 +38,7 @@ def llm():
del llm del llm
cleanup() cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
...@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM): ...@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied # sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None) outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)
def test_chat():
    """Run one two-turn conversation through ``LLM.chat`` and count results."""
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    entropy_question = "Explain the concept of entropy."
    conversation = [
        dict(role="system", content="You are a helpful assistant"),
        dict(role="user", content=entropy_question),
    ]

    results = llm.chat(conversation)
    assert len(results) == 1
def test_multi_chat():
    """Submit two conversations in one ``LLM.chat`` call; expect two outputs."""

    def build_conversation(user_text):
        # Every conversation shares the same system prompt.
        return [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": user_text},
        ]

    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
    batch = [
        build_conversation("Explain the concept of entropy."),
        build_conversation("Explain what among us is."),
    ]

    outputs = llm.chat(batch)
    assert len(outputs) == 2
@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
    """Exercise ``LLM.chat`` with a user message that contains several images.

    One ``image_url`` content part is emitted per entry of ``image_urls``,
    followed by a text part asking about the image.
    """
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 2},
    )

    content = [{
        "type": "image_url",
        "image_url": {
            "url": image_url
        }
    } for image_url in image_urls]
    content.append({"type": "text", "text": "What's in this image?"})

    messages = [{"role": "user", "content": content}]

    outputs = llm.chat(messages)
    # `>= 0` is always true for a length; check for the single output that
    # a single conversation is expected to produce (as test_chat does).
    assert len(outputs) == 1
...@@ -5,10 +5,9 @@ import pytest ...@@ -5,10 +5,9 @@ import pytest
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from vllm import LLM from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
PROMPTS = [ PROMPTS = [
...@@ -39,7 +38,7 @@ def llm(): ...@@ -39,7 +38,7 @@ def llm():
del llm del llm
cleanup() cleanup_dist_env_and_memory()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
......
...@@ -5,12 +5,11 @@ import weakref ...@@ -5,12 +5,11 @@ import weakref
import jsonschema import jsonschema
import pytest import pytest
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
...@@ -23,7 +22,7 @@ def llm(): ...@@ -23,7 +22,7 @@ def llm():
with llm.deprecate_legacy_api(): with llm.deprecate_legacy_api():
yield weakref.proxy(llm) yield weakref.proxy(llm)
del llm del llm
cleanup() cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
......
import pytest
from vllm import LLM
from ...utils import error_on_warning
MODEL_NAME = "facebook/opt-125m"
def test_pos_args_deprecated():
    """Verify which ``LLM(...)`` call shapes emit a ``DeprecationWarning``.

    Passing ``tokenizer`` by keyword must not warn, whether the model is
    given by keyword or positionally. Passing the tokenizer (and tokenizer
    mode) positionally must warn and name the offending parameters.
    """
    # Keyword ``tokenizer``: any DeprecationWarning is turned into an error.
    with error_on_warning(DeprecationWarning):
        LLM(model=MODEL_NAME, tokenizer=MODEL_NAME)

    with error_on_warning(DeprecationWarning):
        LLM(MODEL_NAME, tokenizer=MODEL_NAME)

    # Positional ``tokenizer``: the warning must mention it.
    tokenizer_only = "'tokenizer'"
    with pytest.warns(DeprecationWarning, match=tokenizer_only):
        LLM(MODEL_NAME, MODEL_NAME)

    # Positional ``tokenizer`` and ``tokenizer_mode``: both are listed.
    tokenizer_and_mode = "'tokenizer', 'tokenizer_mode'"
    with pytest.warns(DeprecationWarning, match=tokenizer_and_mode):
        LLM(MODEL_NAME, MODEL_NAME, "auto")
import sys import sys
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
def test_lazy_outlines(sample_regex): def test_lazy_outlines(sample_regex):
...@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex): ...@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
] ]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM without guided decoding as a baseline.
llm = LLM(model="facebook/opt-125m", llm = LLM(model="facebook/opt-125m",
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.3) gpu_memory_utilization=0.3)
...@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex): ...@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex):
# make sure outlines is not imported # make sure outlines is not imported
assert 'outlines' not in sys.modules assert 'outlines' not in sys.modules
# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()
# Create an LLM with guided decoding enabled.
llm = LLM(model="facebook/opt-125m", llm = LLM(model="facebook/opt-125m",
enforce_eager=True, enforce_eager=True,
guided_decoding_backend="lm-format-enforcer", guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3) gpu_memory_utilization=0.6)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate( outputs = llm.generate(
prompts=[ prompts=[
......
"""Tests for HF_HUB_OFFLINE mode""" """Tests for HF_HUB_OFFLINE mode"""
import importlib import importlib
import sys import sys
import weakref
import pytest import pytest
from vllm import LLM from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from ...conftest import cleanup
MODEL_CONFIGS = [
MODEL_NAME = "facebook/opt-125m" {
"model": "facebook/opt-125m",
"enforce_eager": True,
"gpu_memory_utilization": 0.20,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
},
{
"model": "mistralai/Mistral-7B-Instruct-v0.1",
"enforce_eager": True,
"gpu_memory_utilization": 0.95,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
"tokenizer_mode": "mistral",
},
]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def cache_models():
# pytest caches the fixture so we use weakref.proxy to # Cache model files first
# enable garbage collection for model_config in MODEL_CONFIGS:
llm = LLM(model=MODEL_NAME, LLM(**model_config)
max_num_batched_tokens=4096, cleanup_dist_env_and_memory()
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm yield
cleanup()
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_offline_mode(llm: LLM, monkeypatch): @pytest.mark.usefixtures("cache_models")
# we use the llm fixture to ensure the model files are in-cache def test_offline_mode(monkeypatch):
del llm
# Set HF to offline mode and ensure we can still construct an LLM # Set HF to offline mode and ensure we can still construct an LLM
try: try:
monkeypatch.setenv("HF_HUB_OFFLINE", "1") monkeypatch.setenv("HF_HUB_OFFLINE", "1")
# Need to re-import huggingface_hub and friends to setup offline mode # Need to re-import huggingface_hub and friends to setup offline mode
_re_import_modules() _re_import_modules()
# Cached model files should be used in offline mode # Cached model files should be used in offline mode
LLM(model=MODEL_NAME, for model_config in MODEL_CONFIGS:
max_num_batched_tokens=4096, LLM(**model_config)
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
finally: finally:
# Reset the environment after the test # Reset the environment after the test
# NB: Assuming tests are run in online mode # NB: Assuming tests are run in online mode
......
...@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401 ...@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): ...@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI): async def test_response_format_json_schema(client: openai.AsyncOpenAI):
prompt = 'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema
for _ in range(2): for _ in range(2):
resp = await client.chat.completions.create( resp = await client.chat.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
messages=[{ messages=[{
"role": "role": "user",
"user", "content": prompt
"content": ('what is 1+1? please respond with a JSON object, ' }],
'the format is {"result": 2}') )
content = resp.choices[0].message.content
assert content is not None
with pytest.raises((json.JSONDecodeError, AssertionError)):
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": prompt
}], }],
response_format={ response_format={
"type": "json_schema", "type": "json_schema",
......
...@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, ...@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
    """Streaming for parallel sampling.

    The tokens from multiple samples are flattened into a single stream,
    with an index to indicate which sample the token belongs to.
    """
    num_samples = 3
    tokens_per_sample = 5

    stream = await client.completions.create(model=model_name,
                                             prompt="What is an LLM?",
                                             max_tokens=tokens_per_sample,
                                             n=num_samples,
                                             stream=True)

    # Bucket the streamed text pieces by the sample index they belong to.
    per_sample_texts: List[List[str]] = [[] for _ in range(num_samples)]
    finished = 0
    async for event in stream:
        choice = event.choices[0]
        per_sample_texts[choice.index].append(choice.text)
        if choice.finish_reason is not None:
            finished += 1

    # Every sample must report a finish reason exactly once ...
    assert finished == num_samples
    # ... and must have streamed one chunk per generated token.
    for texts in per_sample_texts:
        assert len(texts) == tokens_per_sample
        print("".join(texts))
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
......
...@@ -22,12 +22,12 @@ class MockHFConfig: ...@@ -22,12 +22,12 @@ class MockHFConfig:
@dataclass @dataclass
class MockModelConfig: class MockModelConfig:
task = "generate"
tokenizer = MODEL_NAME tokenizer = MODEL_NAME
trust_remote_code = False trust_remote_code = False
tokenizer_mode = "auto" tokenizer_mode = "auto"
max_model_len = 100 max_model_len = 100
tokenizer_revision = None tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig() multimodal_config = MultiModalConfig()
hf_config = MockHFConfig() hf_config = MockHFConfig()
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "meta-llama/Llama-3.2-1B"
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [ ...@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = [ args = [
"--task",
"generate",
"--dtype", "--dtype",
"bfloat16", "bfloat16",
"--max-model-len", "--max-model-len",
......
...@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" ...@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def phi3v_model_config(): def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID, return ModelConfig(PHI3V_MODEL_ID,
PHI3V_MODEL_ID, task="generate",
tokenizer=PHI3V_MODEL_ID,
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=True, trust_remote_code=True,
dtype="bfloat16", dtype="bfloat16",
...@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ...@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text": "What about these two?" "text": "What about these two?"
}] }]
}], phi3v_model_config, phi3v_tokenizer) }], phi3v_model_config, phi3v_tokenizer)
def test_parse_chat_messages_multiple_images_uncommon_input(
    phi3v_model_config,
    phi3v_tokenizer,
    image_url,
):
    """Parse a user message whose content parts use a loose format.

    The content list holds a bare string for the text and two dicts that
    carry only an ``image_url`` key (no ``type`` field). The parsed
    conversation is expected to contain both image placeholders ahead of
    the text, and the multimodal data to hold two images.
    """
    user_content = ["What's in these images?"]
    # Build two distinct dict objects, as the literal form would.
    user_content.extend({"image_url": image_url} for _ in range(2))
    request = [{"role": "user", "content": user_content}]

    conversation, mm_data = parse_chat_messages(request, phi3v_model_config,
                                                phi3v_tokenizer)

    expected_conversation = [{
        "role": "user",
        "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?",
    }]
    assert conversation == expected_conversation
    _assert_mm_data_is_image_input(mm_data, 2)
...@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch): ...@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name) override_backend_env_variable(monkeypatch, name)
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True): with patch("vllm.attention.selector.current_platform.is_cpu",
backend = which_attn_to_use(16, None, torch.float16, torch.float16, return_value=True):
16, False) backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA" assert backend.name == "TORCH_SDPA"
elif device == "hip": elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True): with patch("vllm.attention.selector.is_hip", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
16, False) False)
assert backend.name == "ROCM_FLASH" assert backend.name == "ROCM_FLASH"
elif device == "openvino": elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True): with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
16, False) False)
assert backend.name == "OPENVINO" assert backend.name == "OPENVINO"
else: else:
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False) False)
assert backend.name == name assert backend.name == name
...@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch): ...@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch # Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)): with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(16, None, torch.float16, None, 16, False) backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type # Unsupported data type
backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False) backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type # Unsupported kv cache data type
backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False) backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size # Unsupported block size
backend = which_attn_to_use(16, None, torch.float16, None, 8, False) backend = which_attn_to_use(16, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed # flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}): with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(16, None, torch.float16, None, 16, False) backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size # Unsupported head size
backend = which_attn_to_use(17, None, torch.float16, None, 16, False) backend = which_attn_to_use(17, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention # Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
True)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
...@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch): ...@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid.""" """Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL) override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError): with pytest.raises(ValueError):
which_attn_to_use(16, None, torch.float16, None, 16, False) which_attn_to_use(16, torch.float16, None, 16, False)
\ No newline at end of file
...@@ -78,6 +78,7 @@ def ref_paged_attn( ...@@ -78,6 +78,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@torch.inference_mode() @torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
kv_lens: List[int], kv_lens: List[int],
...@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv( ...@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
sliding_window: Optional[int],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
seed_everything(0) seed_everything(0)
...@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv( ...@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens) max_kv_len = max(kv_lens)
scale = head_size**-0.5 scale = head_size**-0.5
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks, key_cache = torch.randn(num_blocks,
...@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv( ...@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv(
block_table=block_tables, block_table=block_tables,
cache_seqlens=kv_lens_tensor, cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
).squeeze(1) ).squeeze(1)
ref_output = ref_paged_attn( ref_output = ref_paged_attn(query=query,
query=query, key_cache=key_cache,
key_cache=key_cache, value_cache=value_cache,
value_cache=value_cache, query_lens=[1] * num_seqs,
query_lens=[1] * num_seqs, kv_lens=kv_lens,
kv_lens=kv_lens, block_tables=block_tables,
block_tables=block_tables, scale=scale,
scale=scale, soft_cap=soft_cap,
soft_cap=soft_cap, sliding_window=sliding_window)
)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv( ...@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None]) @pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
...@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv( ...@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens) max_query_len = max(query_lens)
max_kv_len = max(kv_lens) max_kv_len = max(kv_lens)
window_size = ((sliding_window, window_size = ((sliding_window - 1, 0) if sliding_window is not None else
sliding_window) if sliding_window is not None else
(-1, -1)) (-1, -1))
scale = head_size**-0.5 scale = head_size**-0.5
......
...@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor, ...@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q = w_q.t().contiguous().t() # convert to col major w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype) w_q_machete = ops.machete_prepack_B(w_q, wtype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype)) opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
return w_ref, w_q_machete, w_s, w_zp return w_ref, w_q_machete, w_s, w_zp
...@@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype, ...@@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule=schedule, schedule=schedule,
) )
opcheck(torch.ops._C.machete_gemm, opcheck(
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints( torch.ops._C.machete_gemm,
w_zp, w_s), group_size, None, None, None, schedule)) (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule))
# Relax atol as our reduction dim becomes larger (more rounding error) # Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies # Relax atol when we have zeropoints since the way machete applies
......
...@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm( ...@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
opcheck( opcheck(
torch.ops._C.gptq_marlin_gemm, torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1], workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce), a_input.shape[1], is_k_full, False, use_fp32_reduce),
test_utils=DEFAULT_OPCHECK_TEST_UTILS) test_utils=DEFAULT_OPCHECK_TEST_UTILS)
...@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm( ...@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
assert max_diff < 0.04 assert max_diff < 0.04
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(
    a_input,
    marlin_24_q_w_comp,
    marlin_24_meta,
    marlin_24_s,
    scratch,
    quant_type,
    size_m,
    size_n,
    size_k,
):
    """Call ``ops.gptq_marlin_24_gemm`` from inside a ``torch.compile`` graph.

    All arguments are forwarded positionally and unchanged; the wrapper
    exists so the op is invoked under ``fullgraph=True`` compilation.
    """
    result = ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )
    return result
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
...@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, ...@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
opcheck(torch.ops._C.gptq_marlin_24_gemm, opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type, a_input.shape[0], workspace_24.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]), b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS) test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_24_gemm( output = marlin_24_gemm_tester(
a_input, a_input,
marlin_24_q_w_comp, marlin_24_q_w_comp,
marlin_24_meta, marlin_24_meta,
......
...@@ -240,8 +240,8 @@ def test_fused_marlin_moe( ...@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
requires_grad=False) requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe, opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids, (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m, scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
2 * n, k, True, e, topk, block_size_m, True, False)) m, 2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skip("This test is here for the sake of debugging, " @pytest.mark.skip("This test is here for the sake of debugging, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment