Commit 2216a4e5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
...@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple ...@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
import pytest import pytest
from transformers import AutoModelForSeq2SeqLM from transformers import AutoModelForSeq2SeqLM
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from ..conftest import DecoderPromptType from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close from ..models.utils import check_logprobs_close
...@@ -35,7 +35,7 @@ def vllm_to_hf_output( ...@@ -35,7 +35,7 @@ def vllm_to_hf_output(
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.skipif( @pytest.mark.skipif(
is_cpu(), current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models" reason="CPU backend is not currently supported with encoder/decoder models"
) )
def test_encoder_decoder_e2e( def test_encoder_decoder_e2e(
......
from typing import List
import pytest
from vllm import LLM
from ..openai.test_vision import TEST_IMAGE_URLS
def test_chat():
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
prompt1 = "Explain the concept of entropy."
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
outputs = llm.chat(messages)
assert len(outputs) == 1
def test_multi_chat():
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
prompt1 = "Explain the concept of entropy."
prompt2 = "Explain what among us is."
conversation1 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
conversation2 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt2
},
]
messages = [conversation1, conversation2]
outputs = llm.chat(messages)
assert len(outputs) == 2
@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
trust_remote_code=True,
limit_mm_per_prompt={"image": 2},
)
messages = [{
"role":
"user",
"content": [
*({
"type": "image_url",
"image_url": {
"url": image_url
}
} for image_url in image_urls),
{
"type": "text",
"text": "What's in this image?"
},
],
}]
outputs = llm.chat(messages)
assert len(outputs) >= 0
...@@ -4,8 +4,7 @@ from typing import List ...@@ -4,8 +4,7 @@ from typing import List
import pytest import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
from ...conftest import cleanup
MODEL_NAME = "intfloat/e5-mistral-7b-instruct" MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
...@@ -41,7 +40,7 @@ def llm(): ...@@ -41,7 +40,7 @@ def llm():
del llm del llm
cleanup() cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput], def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
......
...@@ -4,9 +4,7 @@ from typing import List ...@@ -4,9 +4,7 @@ from typing import List
import pytest import pytest
from vllm import LLM, RequestOutput, SamplingParams from vllm import LLM, RequestOutput, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from ...conftest import cleanup
from ..openai.test_vision import TEST_IMAGE_URLS
MODEL_NAME = "facebook/opt-125m" MODEL_NAME = "facebook/opt-125m"
...@@ -40,7 +38,7 @@ def llm(): ...@@ -40,7 +38,7 @@ def llm():
del llm del llm
cleanup() cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
...@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM): ...@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied # sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None) outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)
def test_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompt1 = "Explain the concept of entropy."
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
outputs = llm.chat(messages)
assert len(outputs) == 1
def test_multi_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompt1 = "Explain the concept of entropy."
prompt2 = "Explain what among us is."
conversation1 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
conversation2 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt2
},
]
messages = [conversation1, conversation2]
outputs = llm.chat(messages)
assert len(outputs) == 2
@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
trust_remote_code=True,
limit_mm_per_prompt={"image": 2},
)
messages = [{
"role":
"user",
"content": [
*({
"type": "image_url",
"image_url": {
"url": image_url
}
} for image_url in image_urls),
{
"type": "text",
"text": "What's in this image?"
},
],
}]
outputs = llm.chat(messages)
assert len(outputs) >= 0
...@@ -5,10 +5,9 @@ import pytest ...@@ -5,10 +5,9 @@ import pytest
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from vllm import LLM from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
PROMPTS = [ PROMPTS = [
...@@ -39,7 +38,7 @@ def llm(): ...@@ -39,7 +38,7 @@ def llm():
del llm del llm
cleanup() cleanup_dist_env_and_memory()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
......
...@@ -5,12 +5,11 @@ import weakref ...@@ -5,12 +5,11 @@ import weakref
import jsonschema import jsonschema
import pytest import pytest
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
...@@ -23,7 +22,7 @@ def llm(): ...@@ -23,7 +22,7 @@ def llm():
with llm.deprecate_legacy_api(): with llm.deprecate_legacy_api():
yield weakref.proxy(llm) yield weakref.proxy(llm)
del llm del llm
cleanup() cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
......
import pytest
from vllm import LLM
from ...utils import error_on_warning
MODEL_NAME = "facebook/opt-125m"
def test_pos_args_deprecated():
with error_on_warning(DeprecationWarning):
LLM(model=MODEL_NAME, tokenizer=MODEL_NAME)
with error_on_warning(DeprecationWarning):
LLM(MODEL_NAME, tokenizer=MODEL_NAME)
with pytest.warns(DeprecationWarning, match="'tokenizer'"):
LLM(MODEL_NAME, MODEL_NAME)
with pytest.warns(DeprecationWarning,
match="'tokenizer', 'tokenizer_mode'"):
LLM(MODEL_NAME, MODEL_NAME, "auto")
import sys import sys
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
def test_lazy_outlines(sample_regex): def test_lazy_outlines(sample_regex):
...@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex): ...@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
] ]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM without guided decoding as a baseline.
llm = LLM(model="facebook/opt-125m", llm = LLM(model="facebook/opt-125m",
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.3) gpu_memory_utilization=0.3)
...@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex): ...@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex):
# make sure outlines is not imported # make sure outlines is not imported
assert 'outlines' not in sys.modules assert 'outlines' not in sys.modules
# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()
# Create an LLM with guided decoding enabled.
llm = LLM(model="facebook/opt-125m", llm = LLM(model="facebook/opt-125m",
enforce_eager=True, enforce_eager=True,
guided_decoding_backend="lm-format-enforcer", guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3) gpu_memory_utilization=0.6)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate( outputs = llm.generate(
prompts=[ prompts=[
......
"""Tests for HF_HUB_OFFLINE mode""" """Tests for HF_HUB_OFFLINE mode"""
import importlib import importlib
import sys import sys
import weakref
import pytest import pytest
from vllm import LLM from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from ...conftest import cleanup
MODEL_CONFIGS = [
MODEL_NAME = "facebook/opt-125m" {
"model": "facebook/opt-125m",
"enforce_eager": True,
"gpu_memory_utilization": 0.20,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
},
{
"model": "mistralai/Mistral-7B-Instruct-v0.1",
"enforce_eager": True,
"gpu_memory_utilization": 0.95,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
"tokenizer_mode": "mistral",
},
]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def cache_models():
# pytest caches the fixture so we use weakref.proxy to # Cache model files first
# enable garbage collection for model_config in MODEL_CONFIGS:
llm = LLM(model=MODEL_NAME, LLM(**model_config)
max_num_batched_tokens=4096, cleanup_dist_env_and_memory()
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm yield
cleanup()
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_offline_mode(llm: LLM, monkeypatch): @pytest.mark.usefixtures("cache_models")
# we use the llm fixture to ensure the model files are in-cache def test_offline_mode(monkeypatch):
del llm
# Set HF to offline mode and ensure we can still construct an LLM # Set HF to offline mode and ensure we can still construct an LLM
try: try:
monkeypatch.setenv("HF_HUB_OFFLINE", "1") monkeypatch.setenv("HF_HUB_OFFLINE", "1")
# Need to re-import huggingface_hub and friends to setup offline mode # Need to re-import huggingface_hub and friends to setup offline mode
_re_import_modules() _re_import_modules()
# Cached model files should be used in offline mode # Cached model files should be used in offline mode
LLM(model=MODEL_NAME, for model_config in MODEL_CONFIGS:
max_num_batched_tokens=4096, LLM(**model_config)
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
finally: finally:
# Reset the environment after the test # Reset the environment after the test
# NB: Assuming tests are run in online mode # NB: Assuming tests are run in online mode
......
...@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401 ...@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): ...@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI): async def test_response_format_json_schema(client: openai.AsyncOpenAI):
prompt = 'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema
for _ in range(2): for _ in range(2):
resp = await client.chat.completions.create( resp = await client.chat.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
messages=[{ messages=[{
"role": "role": "user",
"user", "content": prompt
"content": ('what is 1+1? please respond with a JSON object, ' }],
'the format is {"result": 2}') )
content = resp.choices[0].message.content
assert content is not None
with pytest.raises((json.JSONDecodeError, AssertionError)):
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": prompt
}], }],
response_format={ response_format={
"type": "json_schema", "type": "json_schema",
......
...@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, ...@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""
prompt = "What is an LLM?"
n = 3
max_tokens = 5
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
stream=True)
chunks: List[List[str]] = [[] for i in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
text = chunk.choices[0].text
chunks[index].append(text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == n
for chunk in chunks:
assert len(chunk) == max_tokens
print("".join(chunk))
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
......
...@@ -22,12 +22,12 @@ class MockHFConfig: ...@@ -22,12 +22,12 @@ class MockHFConfig:
@dataclass @dataclass
class MockModelConfig: class MockModelConfig:
task = "generate"
tokenizer = MODEL_NAME tokenizer = MODEL_NAME
trust_remote_code = False trust_remote_code = False
tokenizer_mode = "auto" tokenizer_mode = "auto"
max_model_len = 100 max_model_len = 100
tokenizer_revision = None tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig() multimodal_config = MultiModalConfig()
hf_config = MockHFConfig() hf_config = MockHFConfig()
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "meta-llama/Llama-3.2-1B"
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [ ...@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = [ args = [
"--task",
"generate",
"--dtype", "--dtype",
"bfloat16", "bfloat16",
"--max-model-len", "--max-model-len",
......
...@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" ...@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def phi3v_model_config(): def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID, return ModelConfig(PHI3V_MODEL_ID,
PHI3V_MODEL_ID, task="generate",
tokenizer=PHI3V_MODEL_ID,
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=True, trust_remote_code=True,
dtype="bfloat16", dtype="bfloat16",
...@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ...@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text": "What about these two?" "text": "What about these two?"
}] }]
}], phi3v_model_config, phi3v_tokenizer) }], phi3v_model_config, phi3v_tokenizer)
def test_parse_chat_messages_multiple_images_uncommon_input(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
"What's in these images?", {
"image_url": image_url
}, {
"image_url": image_url
}
]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [{
"role":
"user",
"content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
_assert_mm_data_is_image_input(mm_data, 2)
...@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch): ...@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name) override_backend_env_variable(monkeypatch, name)
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True): with patch("vllm.attention.selector.current_platform.is_cpu",
backend = which_attn_to_use(16, None, torch.float16, torch.float16, return_value=True):
16, False) backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA" assert backend.name == "TORCH_SDPA"
elif device == "hip": elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True): with patch("vllm.attention.selector.is_hip", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
16, False) False)
assert backend.name == "ROCM_FLASH" assert backend.name == "ROCM_FLASH"
elif device == "openvino": elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True): with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
16, False) False)
assert backend.name == "OPENVINO" assert backend.name == "OPENVINO"
else: else:
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False) False)
assert backend.name == name assert backend.name == name
...@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch): ...@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch # Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)): with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(16, None, torch.float16, None, 16, False) backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type # Unsupported data type
backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False) backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type # Unsupported kv cache data type
backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False) backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size # Unsupported block size
backend = which_attn_to_use(16, None, torch.float16, None, 8, False) backend = which_attn_to_use(16, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed # flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}): with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(16, None, torch.float16, None, 16, False) backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size # Unsupported head size
backend = which_attn_to_use(17, None, torch.float16, None, 16, False) backend = which_attn_to_use(17, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention # Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
True)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.name != STR_FLASH_ATTN_VAL
...@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch): ...@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid.""" """Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL) override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError): with pytest.raises(ValueError):
which_attn_to_use(16, None, torch.float16, None, 16, False) which_attn_to_use(16, torch.float16, None, 16, False)
\ No newline at end of file
...@@ -78,6 +78,7 @@ def ref_paged_attn( ...@@ -78,6 +78,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@torch.inference_mode() @torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
kv_lens: List[int], kv_lens: List[int],
...@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv( ...@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
sliding_window: Optional[int],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
seed_everything(0) seed_everything(0)
...@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv( ...@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens) max_kv_len = max(kv_lens)
scale = head_size**-0.5 scale = head_size**-0.5
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks, key_cache = torch.randn(num_blocks,
...@@ -121,10 +125,10 @@ def test_flash_attn_with_paged_kv( ...@@ -121,10 +125,10 @@ def test_flash_attn_with_paged_kv(
block_table=block_tables, block_table=block_tables,
cache_seqlens=kv_lens_tensor, cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
).squeeze(1) ).squeeze(1)
ref_output = ref_paged_attn( ref_output = ref_paged_attn(query=query,
query=query,
key_cache=key_cache, key_cache=key_cache,
value_cache=value_cache, value_cache=value_cache,
query_lens=[1] * num_seqs, query_lens=[1] * num_seqs,
...@@ -132,7 +136,7 @@ def test_flash_attn_with_paged_kv( ...@@ -132,7 +136,7 @@ def test_flash_attn_with_paged_kv(
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap, soft_cap=soft_cap,
) sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv( ...@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None]) @pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
...@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv( ...@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens) max_query_len = max(query_lens)
max_kv_len = max(kv_lens) max_kv_len = max(kv_lens)
window_size = ((sliding_window, window_size = ((sliding_window - 1, 0) if sliding_window is not None else
sliding_window) if sliding_window is not None else
(-1, -1)) (-1, -1))
scale = head_size**-0.5 scale = head_size**-0.5
......
...@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor, ...@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q = w_q.t().contiguous().t() # convert to col major w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype) w_q_machete = ops.machete_prepack_B(w_q, wtype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype)) opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
return w_ref, w_q_machete, w_s, w_zp return w_ref, w_q_machete, w_s, w_zp
...@@ -153,8 +153,9 @@ def test_machete_all_schedules(shape, atype: torch.dtype, ...@@ -153,8 +153,9 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule=schedule, schedule=schedule,
) )
opcheck(torch.ops._C.machete_gemm, opcheck(
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints( torch.ops._C.machete_gemm,
(a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule)) w_zp, w_s), group_size, None, None, None, schedule))
# Relax atol as our reduction dim becomes larger (more rounding error) # Relax atol as our reduction dim becomes larger (more rounding error)
......
...@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm( ...@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
opcheck( opcheck(
torch.ops._C.gptq_marlin_gemm, torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1], workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce), a_input.shape[1], is_k_full, False, use_fp32_reduce),
test_utils=DEFAULT_OPCHECK_TEST_UTILS) test_utils=DEFAULT_OPCHECK_TEST_UTILS)
...@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm( ...@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
assert max_diff < 0.04 assert max_diff < 0.04
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m, size_n,
size_k):
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m,
size_n, size_k)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
...@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, ...@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
opcheck(torch.ops._C.gptq_marlin_24_gemm, opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type, a_input.shape[0], workspace_24.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]), b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS) test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_24_gemm( output = marlin_24_gemm_tester(
a_input, a_input,
marlin_24_q_w_comp, marlin_24_q_w_comp,
marlin_24_meta, marlin_24_meta,
......
...@@ -240,8 +240,8 @@ def test_fused_marlin_moe( ...@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
requires_grad=False) requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe, opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids, (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m, scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
2 * n, k, True, e, topk, block_size_m, True, False)) m, 2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skip("This test is here for the sake of debugging, " @pytest.mark.skip("This test is here for the sake of debugging, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment