Commit 711aa9d5 authored by zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
...@@ -34,6 +34,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 ...@@ -34,6 +34,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
f"zephyr-lora2={zephyr_lora_added_tokens_files}", f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank", "--max-lora-rank",
"64", "64",
"--enable-tokenizer-info-endpoint",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
...@@ -285,3 +286,106 @@ async def test_detokenize( ...@@ -285,3 +286,106 @@ async def test_detokenize(
response.raise_for_status() response.raise_for_status()
assert response.json() == {"prompt": prompt} assert response.json() == {"prompt": prompt}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenizer_info_basic(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """Check that the tokenizer info endpoint reports a tokenizer class.

    The response must contain a non-empty string under ``tokenizer_class``.
    ``model_name`` and ``tokenizer_name`` are supplied by the
    parametrization but are not part of the request itself.
    """
    info_url = server.url_for("tokenizer_info")
    reply = requests.get(info_url)
    reply.raise_for_status()
    payload = reply.json()

    assert "tokenizer_class" in payload
    assert isinstance(payload["tokenizer_class"], str)
    # An empty class name would be useless to clients.
    assert payload["tokenizer_class"]
@pytest.mark.asyncio
async def test_tokenizer_info_schema(server: RemoteOpenAIServer):
    """Verify the JSON type of every known tokenizer info field.

    Fields that are missing from the response, or present with a ``None``
    value, are skipped; only populated fields are type-checked.
    """
    reply = requests.get(server.url_for("tokenizer_info"))
    reply.raise_for_status()
    result = reply.json()

    bool_fields = (
        "add_bos_token",
        "add_prefix_space",
        "clean_up_tokenization_spaces",
        "split_special_tokens",
    )
    str_fields = (
        "bos_token",
        "eos_token",
        "pad_token",
        "unk_token",
        "chat_template",
        "errors",
    )
    field_types = {
        **dict.fromkeys(bool_fields, bool),
        **dict.fromkeys(str_fields, str),
        "model_max_length": int,
        "additional_special_tokens": list,
        "added_tokens_decoder": dict,
    }

    for field, expected_type in field_types.items():
        value = result.get(field)
        if value is None:
            # Absent or null fields are optional.
            continue
        assert isinstance(value, expected_type), (
            f"{field} should be {expected_type.__name__}")
@pytest.mark.asyncio
async def test_tokenizer_info_added_tokens_structure(
        server: RemoteOpenAIServer, ):
    """Validate each ``added_tokens_decoder`` entry when the field is set.

    Every key must be a string and every value a dict that carries
    ``content`` and a boolean ``special`` flag. Nothing is checked when the
    field is missing or empty.
    """
    reply = requests.get(server.url_for("tokenizer_info"))
    reply.raise_for_status()
    added_tokens = reply.json().get("added_tokens_decoder")

    if not added_tokens:
        return

    for token_id, token_info in added_tokens.items():
        assert isinstance(token_id, str), "Token IDs should be strings"
        assert isinstance(token_info, dict), "Token info should be a dict"
        assert "content" in token_info, "Token info should have content"
        assert "special" in token_info, (
            "Token info should have special flag")
        special_flag = token_info["special"]
        assert isinstance(special_flag,
                          bool), ("Special flag should be boolean")
@pytest.mark.asyncio
async def test_tokenizer_info_consistency_with_tokenize(
        server: RemoteOpenAIServer, ):
    """Cross-check tokenizer info against the tokenize endpoint.

    When both endpoints report a maximum length, the tokenizer's
    ``model_max_length`` must be at least the ``max_model_len`` returned by
    the tokenize endpoint.
    """
    info_reply = requests.get(server.url_for("tokenizer_info"))
    info_reply.raise_for_status()
    info = info_reply.json()

    tokenize_payload = {"model": MODEL_NAME, "prompt": "Hello world!"}
    tokenize_reply = requests.post(server.url_for("tokenize"),
                                   json=tokenize_payload)
    tokenize_reply.raise_for_status()
    tokenize_result = tokenize_reply.json()

    info_max_len = info.get("model_max_length")
    tokenize_max_len = tokenize_result.get("max_model_len")
    # Only compare when both values are present and non-zero.
    if not (info_max_len and tokenize_max_len):
        return
    assert info_max_len >= tokenize_max_len, (
        "Info max length should be >= tokenize max length")
@pytest.mark.asyncio
async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
    """Check that a reported chat template is a non-blank string.

    Nothing is asserted when the response has no (or an empty)
    ``chat_template``.
    """
    reply = requests.get(server.url_for("tokenizer_info"))
    reply.raise_for_status()
    chat_template = reply.json().get("chat_template")

    if not chat_template:
        return

    assert isinstance(chat_template,
                      str), ("Chat template should be a string")
    assert chat_template.strip(), "Chat template should not be empty"
\ No newline at end of file
...@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset ...@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
# Extra server CLI arguments appended for models whose name starts with
# "mistralai"; each flag is followed by its value.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode",
    "mistral",
    "--config_format",
    "mistral",
    "--load_format",
    "mistral",
]
@pytest.fixture @pytest.fixture
def mary_had_lamb(): def mary_had_lamb():
...@@ -33,9 +38,15 @@ def winning_call(): ...@@ -33,9 +38,15 @@ def winning_call():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_audio(mary_had_lamb): @pytest.mark.parametrize(
model_name = "openai/whisper-large-v3-turbo" "model_name",
["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name):
server_args = ["--enforce-eager"] server_args = ["--enforce-eager"]
if model_name.startswith("mistralai"):
server_args += MISTRAL_FORMAT_ARGS
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server: with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client() client = remote_server.get_async_client()
...@@ -65,10 +76,13 @@ async def test_bad_requests(mary_had_lamb): ...@@ -65,10 +76,13 @@ async def test_bad_requests(mary_had_lamb):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb): @pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
model_name = "openai/whisper-large-v3-turbo" async def test_long_audio_request(mary_had_lamb, model_name):
server_args = ["--enforce-eager"] server_args = ["--enforce-eager"]
if model_name.startswith("openai"):
return
mary_had_lamb.seek(0) mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb) audio, sr = librosa.load(mary_had_lamb)
# Add small silence after each audio for repeatability in the split process # Add small silence after each audio for repeatability in the split process
...@@ -87,7 +101,8 @@ async def test_long_audio_request(mary_had_lamb): ...@@ -87,7 +101,8 @@ async def test_long_audio_request(mary_had_lamb):
response_format="text", response_format="text",
temperature=0.0) temperature=0.0)
out = json.loads(transcription)['text'] out = json.loads(transcription)['text']
assert out.count("Mary had a little lamb") == 10 counts = out.count("Mary had a little lamb")
assert counts == 10, counts
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -154,7 +169,8 @@ async def test_streaming_response(winning_call): ...@@ -154,7 +169,8 @@ async def test_streaming_response(winning_call):
file=winning_call, file=winning_call,
language="en", language="en",
temperature=0.0, temperature=0.0,
extra_body=dict(stream=True)) extra_body=dict(stream=True),
timeout=30)
# Reconstruct from chunks and validate # Reconstruct from chunks and validate
async for chunk in res: async for chunk in res:
# just a chunk # just a chunk
...@@ -184,7 +200,8 @@ async def test_stream_options(winning_call): ...@@ -184,7 +200,8 @@ async def test_stream_options(winning_call):
temperature=0.0, temperature=0.0,
extra_body=dict(stream=True, extra_body=dict(stream=True,
stream_include_usage=True, stream_include_usage=True,
stream_continuous_usage_stats=True)) stream_continuous_usage_stats=True),
timeout=30)
final = False final = False
continuous = True continuous = True
async for chunk in res: async for chunk in res:
......
...@@ -39,8 +39,8 @@ async def test_basic_audio(foscolo): ...@@ -39,8 +39,8 @@ async def test_basic_audio(foscolo):
# TODO remove once language detection is implemented # TODO remove once language detection is implemented
extra_body=dict(language="it"), extra_body=dict(language="it"),
temperature=0.0) temperature=0.0)
out = json.loads(translation)['text'].strip() out = json.loads(translation)['text'].strip().lower()
assert "Nor will I ever touch the sacred" in out assert "greek sea" in out
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -168,5 +168,4 @@ async def test_long_audio_request(foscolo): ...@@ -168,5 +168,4 @@ async def test_long_audio_request(foscolo):
response_format="text", response_format="text",
temperature=0.0) temperature=0.0)
out = json.loads(translation)['text'].strip().lower() out = json.loads(translation)['text'].strip().lower()
# TODO investigate higher model uncertainty in for longer translations. assert out.count("greek sea") == 2
assert out.count("nor will i ever") == 2
...@@ -45,11 +45,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [ ...@@ -45,11 +45,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [
], ],
[ [
"The image shows a Venn diagram with three over", "The image shows a Venn diagram with three over",
"This image shows a Venn diagram with three over", "The image shows a Venn diagram with three intersect",
], ],
[ [
"This image displays a gradient of colors ranging from", "This image displays a gradient of colors ranging from",
"This image displays a gradient of colors transitioning from", "The image displays a gradient of colors ranging from",
], ],
] ]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
from unittest.mock import MagicMock
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
def make_tool_call(name, arguments):
    """Build a ``function``-type ``ToolCall`` for use as an expected value.

    Args:
        name: Name of the function being called.
        arguments: JSON-serializable arguments; stored as a JSON string.
    """
    serialized_arguments = json.dumps(arguments)
    function_call = FunctionCall(name=name, arguments=serialized_arguments)
    return ToolCall(type="function", function=function_call)
# TODO: add reason prefix and suffix.
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>",  #noqa: E501
            [
                make_tool_call("get_weather", {
                    "city": "San Francisco",
                    "metric": "celsius"
                })
            ],
            None),
        # Multiple tool calls
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>",  #noqa: E501
            [
                make_tool_call("get_weather", {
                    "city": "San Francisco",
                    "metric": "celsius"
                }),
                make_tool_call(
                    "register_user", {
                        "name": "John Doe",
                        "age": 37,
                        "address": {
                            "city": "San Francisco",
                            "state": "CA"
                        },
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"]
                    })
            ],
            None),
        # Content before tool call
        (
            "I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>",  #noqa: E501
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. "),
        # Content after tool call (should be stripped)
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!",  #noqa: E501
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None),
        # Deeply nested arguments
        (
            "<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>",
            [
                make_tool_call(
                    "complex_tool",
                    {"level1": {
                        "level2": {
                            "level3": {
                                "value": 123
                            }
                        }
                    }})
            ],
            None,
        ),
    ])
def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
                                          expected_content):
    """Check non-streaming tool-call extraction of the hunyuan_a13b parser.

    Args:
        model_output: Complete model output text handed to the parser.
        expected_tool_calls: Tool calls the parser should extract.
        expected_content: Remaining plain content, or ``None``.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "hunyuan_a13b")(mock_tokenizer)
    content, tool_calls = run_tool_extraction(tool_parser,
                                              model_output,
                                              streaming=False)

    # Fail with a readable message (instead of an IndexError while aligning
    # ids) when the parser returns a different number of calls.
    assert len(tool_calls) == len(expected_tool_calls), (
        f"expected {len(expected_tool_calls)} tool calls, "
        f"got {len(tool_calls)}")

    # align the random id.
    for actual, expected in zip(tool_calls, expected_tool_calls):
        actual.id = expected.id

    assert tool_calls == expected_tool_calls
    assert content == expected_content
# Streaming test: simulate incremental output
@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
    ([
        "<tool_calls>[{\"name\": \"get_weather\", ",
        "\"arguments\": {\"city\": \"San Francisco\", ",
        "\"metric\": \"celsius\"}}]", "</tool_calls>"
    ], [
        make_tool_call("get_weather", {
            "city": "San Francisco",
            "metric": "celsius"
        })
    ]),
    ([
        "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
        " {\"city\": \"Boston\"}", "}]", "</tool_calls>"
    ], [make_tool_call("get_weather", {"city": "Boston"})]),
    ([
        "", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
        " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>"
    ], [make_tool_call("get_weather", {"city": "Boston"})]),
    pytest.param([
        "<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ",
        " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}",
        "]</tool_calls>"
    ], [
        make_tool_call("complex_tool",
                       {"level1": {
                           "level2": {
                               "level3": {
                                   "value": 123
                               }
                           }
                       }})
    ],
                 marks=pytest.mark.xfail(
                     reason="stream parsing not support nested json yet.")),
])
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Check streaming tool-call extraction of the hunyuan_a13b parser.

    Args:
        model_deltas: Output fragments fed to the parser one at a time.
        expected_tool_calls: Tool calls expected after all deltas are
            consumed.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "hunyuan_a13b")(mock_tokenizer)
    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_deltas, assert_one_tool_per_delta=False)

    streamed_calls = reconstructor.tool_calls
    # Fail with a readable message (instead of an IndexError while aligning
    # ids) when the parser produced a different number of calls.
    assert len(streamed_calls) == len(expected_tool_calls), (
        f"expected {len(expected_tool_calls)} tool calls, "
        f"got {len(streamed_calls)}")

    # align the random id.
    for actual, expected in zip(streamed_calls, expected_tool_calls):
        actual.id = expected.id

    assert reconstructor.tool_calls == expected_tool_calls
...@@ -2,12 +2,20 @@ ...@@ -2,12 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings import warnings
from typing import Optional from collections.abc import Mapping
from typing import Literal, Optional
import pytest import pytest
import os import os
from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy,
SpecialTokens)
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
Tekkenizer)
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages, parse_chat_messages,
...@@ -16,9 +24,12 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, ...@@ -16,9 +24,12 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
resolve_hf_chat_template) resolve_hf_chat_template)
from vllm.entrypoints.llm import apply_hf_chat_template from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64 from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
encode_video_base64)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from ..utils import models_path_prefix from ..utils import models_path_prefix
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH from ..utils import VLLM_PATH
...@@ -30,11 +41,13 @@ ULTRAVOX_MODEL_ID = os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-lla ...@@ -30,11 +41,13 @@ ULTRAVOX_MODEL_ID = os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-lla
QWEN2AUDIO_MODEL_ID = os.path.join(models_path_prefix, "Qwen/Qwen2-Audio-7B-Instruct") QWEN2AUDIO_MODEL_ID = os.path.join(models_path_prefix, "Qwen/Qwen2-Audio-7B-Instruct")
QWEN2VL_MODEL_ID = os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct") QWEN2VL_MODEL_ID = os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct")
QWEN25VL_MODEL_ID = os.path.join(models_path_prefix, "Qwen/Qwen2.5-VL-3B-Instruct") QWEN25VL_MODEL_ID = os.path.join(models_path_prefix, "Qwen/Qwen2.5-VL-3B-Instruct")
QWEN25OMNI_MODEL_ID = os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-7B")
MLLAMA_MODEL_ID = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-11B-Vision-Instruct") MLLAMA_MODEL_ID = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-11B-Vision-Instruct")
LLAMA_GUARD_MODEL_ID = os.path.join(models_path_prefix, "meta-llama/Llama-Guard-3-1B") LLAMA_GUARD_MODEL_ID = os.path.join(models_path_prefix, "meta-llama/Llama-Guard-3-1B")
HERMES_MODEL_ID = os.path.join(models_path_prefix, "NousResearch/Hermes-3-Llama-3.1-8B") HERMES_MODEL_ID = os.path.join(models_path_prefix, "NousResearch/Hermes-3-Llama-3.1-8B")
MISTRAL_MODEL_ID = os.path.join(models_path_prefix, "mistralai/Mistral-Small-3.1-24B-Instruct-2503") MISTRAL_MODEL_ID = os.path.join(models_path_prefix, "mistralai/Mistral-Small-3.1-24B-Instruct-2503")
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def phi3v_model_config(): def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID, return ModelConfig(PHI3V_MODEL_ID,
...@@ -49,6 +62,21 @@ def phi3v_model_config(): ...@@ -49,6 +62,21 @@ def phi3v_model_config():
}) })
@pytest.fixture(scope="function")
def phi3v_model_config_mm_interleaved():
    """Phi-3-vision model config with interleaved multimodal strings enabled.

    Allows up to two images per prompt.
    """
    mm_limits = {
        "image": 2,
    }
    config = ModelConfig(
        PHI3V_MODEL_ID,
        task="generate",
        tokenizer=PHI3V_MODEL_ID,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="auto",
        seed=0,
        interleave_mm_strings=True,
        limit_mm_per_prompt=mm_limits,
    )
    return config
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def phi3v_tokenizer(): def phi3v_tokenizer():
return TokenizerGroup( return TokenizerGroup(
...@@ -59,6 +87,32 @@ def phi3v_tokenizer(): ...@@ -59,6 +87,32 @@ def phi3v_tokenizer():
) )
@pytest.fixture(scope="function")
def qwen25omni_model_config_mm_interleaved():
    """Qwen2.5-Omni model config with interleaved multimodal strings enabled.

    Allows up to two images, one audio clip and one video per prompt.
    """
    mm_limits = {
        "image": 2,
        "audio": 1,
        "video": 1,
    }
    config = ModelConfig(
        QWEN25OMNI_MODEL_ID,
        task="generate",
        tokenizer=QWEN25OMNI_MODEL_ID,
        tokenizer_mode="auto",
        dtype="auto",
        seed=0,
        interleave_mm_strings=True,
        limit_mm_per_prompt=mm_limits,
    )
    return config
@pytest.fixture(scope="module")
def qwen25omni_tokenizer():
    """Module-scoped tokenizer group for the Qwen2.5-Omni model (no LoRA)."""
    tokenizer_group = TokenizerGroup(
        tokenizer_id=QWEN25OMNI_MODEL_ID,
        enable_lora=False,
        max_num_seqs=5,
        max_input_length=None,
    )
    return tokenizer_group
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def mllama_model_config(): def mllama_model_config():
return ModelConfig(MLLAMA_MODEL_ID, return ModelConfig(MLLAMA_MODEL_ID,
...@@ -114,6 +168,20 @@ def image_url(): ...@@ -114,6 +168,20 @@ def image_url():
return f"data:image/jpeg;base64,{base64}" return f"data:image/jpeg;base64,{base64}"
@pytest.fixture(scope="module")
def video_url():
    """Return the 'baby_reading' video asset as a base64 ``data:`` URL."""
    asset = VideoAsset('baby_reading', 1)
    encoded = encode_video_base64(asset.np_ndarrays)
    return f"data:video/jpeg;base64,{encoded}"
@pytest.fixture(scope="module")
def audio_url():
    """Return the 'mary_had_lamb' audio asset as a base64 ``data:`` URL."""
    asset = AudioAsset('mary_had_lamb')
    encoded = encode_audio_base64(*asset.audio_and_sample_rate)
    return f"data:audio/ogg;base64,{encoded}"
def _assert_mm_data_is_image_input( def _assert_mm_data_is_image_input(
mm_data: Optional[MultiModalDataDict], mm_data: Optional[MultiModalDataDict],
image_count: int, image_count: int,
...@@ -127,6 +195,23 @@ def _assert_mm_data_is_image_input( ...@@ -127,6 +195,23 @@ def _assert_mm_data_is_image_input(
assert isinstance(image_data, list) and len(image_data) == image_count assert isinstance(image_data, list) and len(image_data) == image_count
# Modalities that the multimodal assertions below know about.
ModalityType = Literal["image", "video", "audio"]
# Expected number of items per modality.
MultiModalDataCounts = Mapping[ModalityType, int]


def _assert_mm_data_inputs(
    mm_data: Optional[MultiModalDataDict],
    data_count: MultiModalDataCounts,
) -> None:
    """Assert that ``mm_data`` holds exactly the expected items per modality.

    Args:
        mm_data: Parsed multimodal data; must not be ``None``.
        data_count: Expected item count for each modality. The set of
            modalities must match the keys of ``mm_data`` exactly.
    """
    assert mm_data is not None
    expected_modalities = set(data_count.keys())
    assert expected_modalities == (set(mm_data.keys()))

    for modality, expected_len in data_count.items():
        items = mm_data.get(modality)
        assert items is not None
        assert isinstance(items, list) and len(items) == expected_len
def test_parse_chat_messages_single_image( def test_parse_chat_messages_single_image(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer, phi3v_tokenizer,
...@@ -638,6 +723,277 @@ def test_parse_chat_messages_multiple_images_uncommon_input( ...@@ -638,6 +723,277 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
_assert_mm_data_is_image_input(mm_data, 2) _assert_mm_data_is_image_input(mm_data, 2)
def test_parse_chat_messages_multiple_images_interleave(
    phi3v_model_config_mm_interleaved,
    phi3v_tokenizer,
    image_url,
):
    """Image placeholders are interleaved with text in a single message.

    With string content format, each image part is expected to become a
    numbered ``<|image_N|>`` placeholder at its position between the text
    parts.
    """
    user_content = [
        {
            "type": "text",
            "text": "I need you to compare this image"
        },
        {
            "type": "image_url",
            "image_url": {
                "url": image_url
            }
        },
        {
            "type": "text",
            "text": "and this one"
        },
        {
            "type": "image_url",
            "image_url": {
                "url": image_url
            }
        },
        {
            "type": "text",
            "text": "Do they have differences?"
        },
    ]
    messages = [{"role": "user", "content": user_content}]

    conversation, mm_data = parse_chat_messages(
        messages,
        phi3v_model_config_mm_interleaved,
        phi3v_tokenizer,
        content_format="string",
    )

    expected_content = (
        "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n"  # noqa: E501
        "Do they have differences?")
    assert conversation == [{"role": "user", "content": expected_content}]
    _assert_mm_data_is_image_input(mm_data, 2)
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_interleave_async(
    phi3v_model_config_mm_interleaved,
    phi3v_tokenizer,
    image_url,
):
    """Futures variant of the single-message image interleaving test.

    The conversation is checked first; the multimodal data is then awaited
    and must contain two images.
    """
    user_content = [
        {
            "type": "text",
            "text": "I need you to compare this image"
        },
        {
            "type": "image_url",
            "image_url": {
                "url": image_url
            }
        },
        {
            "type": "text",
            "text": "and this one"
        },
        {
            "type": "image_url",
            "image_url": {
                "url": image_url
            }
        },
        {
            "type": "text",
            "text": "Do they have differences?"
        },
    ]
    messages = [{"role": "user", "content": user_content}]

    conversation, mm_data_future = parse_chat_messages_futures(
        messages,
        phi3v_model_config_mm_interleaved,
        phi3v_tokenizer,
        content_format="string",
    )

    expected_content = (
        "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n"  # noqa: E501
        "Do they have differences?")
    assert conversation == [{"role": "user", "content": expected_content}]

    resolved_mm_data = await mm_data_future
    _assert_mm_data_is_image_input(resolved_mm_data, 2)
def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
    phi3v_model_config_mm_interleaved,
    phi3v_tokenizer,
    image_url,
):
    """Image numbering continues across messages when interleaving.

    The image in the first user turn is expected to become ``<|image_1|>``
    and the one in the later user turn ``<|image_2|>``.
    """
    first_user_turn = {
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "What's on this image?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "text",
                "text": "Be accurate."
            },
        ]
    }
    assistant_turn = {"role": "assistant", "content": "Some stuff."}
    second_user_turn = {
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "What's on this image?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
        ]
    }

    conversation, mm_data = parse_chat_messages(
        [first_user_turn, assistant_turn, second_user_turn],
        phi3v_model_config_mm_interleaved,
        phi3v_tokenizer,
        content_format="string",
    )

    expected_conversation = [
        {
            "role": "user",
            "content": "What's on this image?\n<|image_1|>\nBe accurate."
        },
        {
            "role": "assistant",
            "content": "Some stuff."
        },
        {
            "role": "user",
            "content": "What's on this image?\n<|image_2|>"
        },
    ]
    assert conversation == expected_conversation
    _assert_mm_data_is_image_input(mm_data, 2)
def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
        qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer,
        image_url, video_url, audio_url):
    """Image, audio and video placeholders interleave across messages.

    Uses the Qwen2.5-Omni config and checks that each modality's
    placeholder appears at the position of its part, and that the parsed
    data holds two images, one video and one audio item.
    """
    first_user_turn = {
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "What's on this image?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "text",
                "text": "Now listen to this audio"
            },
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
        ]
    }
    assistant_turn = {"role": "assistant", "content": "Some stuff."}
    second_user_turn = {
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "What's on this image?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "text",
                "text": "And what's in the video?"
            },
            {
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            },
        ]
    }

    conversation, mm_data = parse_chat_messages(
        [first_user_turn, assistant_turn, second_user_turn],
        qwen25omni_model_config_mm_interleaved,
        qwen25omni_tokenizer,
        content_format="string",
    )

    expected_first = (
        "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
        "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>"
    )
    expected_second = (
        "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n"
        "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>")
    assert conversation == [
        {
            "role": "user",
            "content": expected_first
        },
        {
            "role": "assistant",
            "content": "Some stuff."
        },
        {
            "role": "user",
            "content": expected_second
        },
    ]

    _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1})
def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
    phi3v_model_config_mm_interleaved,
    phi3v_tokenizer,
    image_url,
):
    """Explicit placeholders plus interleaving must raise a ValueError.

    The text part already contains ``<|image_1|>`` / ``<|image_2|>`` while
    interleaving is enabled, so parsing is expected to report more
    placeholders than multimodal items.
    """
    # Local import keeps this block self-contained.
    import re

    # ``match`` is applied as a regular expression. Unescaped, the ``|``
    # characters in the placeholder act as alternation and the pattern
    # would match nearly any message, so escape it to compare literally.
    expected_message = re.escape(
        "Found more '<|image_1|>' placeholders in input prompt "
        "than actual multimodal data items.")

    with pytest.raises(ValueError, match=expected_message):
        parse_chat_messages(
            [{
                "role":
                "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url
                        }
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url
                        }
                    },
                    {
                        "type":
                        "text",
                        "text":
                        "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n"  # noqa: E501
                        "Do they have differences?"
                    },
                ]
            }],
            phi3v_model_config_mm_interleaved,
            phi3v_tokenizer,
            content_format="string",
        )
### Mllama currently wraps images / texts as interleaved dictionaries ### Mllama currently wraps images / texts as interleaved dictionaries
def test_mllama_single_image( def test_mllama_single_image(
mllama_model_config, mllama_model_config,
...@@ -1027,3 +1383,165 @@ def test_resolve_content_format_examples(template_path, expected_format): ...@@ -1027,3 +1383,165 @@ def test_resolve_content_format_examples(template_path, expected_format):
) )
assert resolved_format == expected_format assert resolved_format == expected_format
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config,
                                                    mistral_tokenizer):
    """Thinking chunks are parsed into text parts in ``openai`` format.

    The expected conversation turns every ``thinking`` part into a
    ``text`` part carrying the thinking string, and wraps the plain-string
    user content in a single text part.
    """
    system_turn = {
        "role":
        "system",
        "content": [{
            "type": "text",
            "text": "You are a helpful assistant."
        }, {
            "type": "thinking",
            "closed": True,
            "thinking": "Only return the answer when you are confident."
        }]
    }
    user_turn = {"role": "user", "content": "What is 2+2?"}
    assistant_turn = {
        "role":
        "assistant",
        "content": [{
            "type": "text",
            "text": "Let me think about it."
        }, {
            "type": "thinking",
            "closed": True,
            "thinking": "2+2 = 4"
        }, {
            "type": "text",
            "text": "The answer is 4.",
        }],
    }
    messages = [system_turn, user_turn, assistant_turn]

    conversation_with_thinking, _ = parse_chat_messages(
        messages,
        mistral_model_config,
        mistral_tokenizer,
        content_format="openai",
    )

    def text_part(text):
        # Shape of a plain text content part in the expected output.
        return {"type": "text", "text": text}

    expected_conversation = [
        {
            "role":
            "system",
            "content": [
                text_part("You are a helpful assistant."),
                text_part("Only return the answer when you are confident."),
            ],
        },
        {
            "role": "user",
            "content": [text_part("What is 2+2?")],
        },
        {
            "role":
            "assistant",
            "content": [
                text_part("Let me think about it."),
                text_part("2+2 = 4"),
                text_part("The answer is 4."),
            ]
        },
    ]

    assert conversation_with_thinking == expected_conversation
def test_apply_mistral_chat_template_thinking_chunk():
    """Thinking chunks are rendered between think tokens by the template.

    A Devstral tokenizer is patched in place so that special-token ranks
    35 and 36 act as the begin/end think tokens; the chat is then
    tokenized, decoded with special tokens kept, and compared to the
    expected prompt string.
    """
    # Moved import here to avoid yapf and isort conflicts
    from vllm.entrypoints.chat_utils import apply_mistral_chat_template

    system_turn = {
        "role":
        "system",
        "content": [{
            "type": "text",
            "text": "You are a helpful assistant."
        }, {
            "type": "thinking",
            "closed": True,
            "thinking": "Only return the answer when you are confident."
        }]
    }
    assistant_turn = {
        "role":
        "assistant",
        "content": [{
            "type": "text",
            "text": "Let me think about it."
        }, {
            "type": "thinking",
            "closed": True,
            "thinking": "2+2 = 4"
        }, {
            "type": "text",
            "text": "The answer is 4.",
        }],
    }
    messages = [
        system_turn,
        {
            "role": "user",
            "content": "What is 2+2?"
        },
        assistant_turn,
        {
            "role": "user",
            "content": "Thanks, what is 3+3?"
        },
    ]

    # TODO(Julien): upon model release change to a tokenizer already configured.
    # =================================================================
    mistral_tokenizer = MistralTokenizer.from_pretrained(
        "mistralai/Devstral-Small-2507")
    tekken = mistral_tokenizer.tokenizer
    assert isinstance(tekken, Tekkenizer)

    begin_think = SpecialTokens.begin_think.value
    end_think = SpecialTokens.end_think.value

    # Add think special tokens to the tokenizer
    tekken._all_special_tokens[35] = SpecialTokenInfo(rank=35,
                                                      is_control=True,
                                                      token_str=begin_think)
    tekken._all_special_tokens[36] = SpecialTokenInfo(rank=36,
                                                      is_control=True,
                                                      token_str=end_think)
    # Drop whatever tokens previously mapped to ranks 35/36 before
    # registering the think tokens under those ranks.
    tekken._special_tokens_reverse_vocab = {
        k: v
        for k, v in tekken._special_tokens_reverse_vocab.items()
        if v not in {35, 36}
    }
    tekken._special_tokens_reverse_vocab[begin_think] = 35
    tekken._special_tokens_reverse_vocab[end_think] = 36
    mistral_tokenizer.instruct.BEGIN_THINK = 35
    mistral_tokenizer.instruct.END_THINK = 36
    # =================================================================

    tokens_ids = apply_mistral_chat_template(mistral_tokenizer,
                                             messages,
                                             chat_template=None,
                                             tools=None)

    string_tokens = mistral_tokenizer.mistral.decode(
        tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP)

    expected_tokens = (
        r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
        r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
        r"[INST]What is 2+2?[/INST]"
        r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
        r"[INST]Thanks, what is 3+3?[/INST]")

    assert string_tokens == expected_tokens
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import flashinfer
import pytest
import torch
from vllm.platforms import current_platform
if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
allow_module_level=True)
# Size of one float32 element in bytes (bit width divided by 8).
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
# (num_query_heads, num_kv_heads) pairs; the decode test in this file
# reads index 0 as the query head count and index 1 as the KV head count.
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
# Per-head dimension values to parametrize over.
HEAD_SIZES = [128]
# KV cache block (page) sizes to parametrize over.
BLOCK_SIZES = [16, 32]
# Query / KV cache dtypes to parametrize over.
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
# Values passed as logits_soft_cap when planning the decode wrapper.
SOFT_CAPS = [None, 30.0, 50.0]
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", ["HND"])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    kv_layout: str,
) -> None:
    """Compare the TRT-LLM batch decode kernel with FlashInfer paged decode.

    A random query and a random paged KV cache are decoded twice: once via
    ``flashinfer.BatchDecodeWithPagedKVCacheWrapper`` (the baseline) and once
    via ``flashinfer.decode.trtllm_batch_decode_with_kv_cache``. The two
    outputs must agree within ``atol=1e-2`` / ``rtol=1e-2``.

    Args:
        kv_lens: KV length of each sequence in the batch; each must be > 0.
        num_heads: ``(num_query_heads, num_kv_heads)``; the first must be
            divisible by the second.
        head_size: Per-head hidden dimension.
        dtype: Data type of the query and the KV cache.
        block_size: Number of tokens per KV-cache block.
        soft_cap: Logits soft cap passed to the baseline ``plan`` call.
        kv_layout: ``"NHD"`` or ``"HND"``; selects the KV-cache shape.

    Raises:
        ValueError: If ``kv_layout`` is neither ``"NHD"`` nor ``"HND"``.
    """
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)

    num_seqs = len(kv_lens)
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)

    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)

    # Each sequence gets randomly chosen block ids; the table is padded to the
    # block count of the longest sequence.
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = 1.0

    # Build the CSR-style page metadata consumed by the baseline wrapper.
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for seq_idx, seq_len in enumerate(kv_lens):
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        # .tolist() yields plain ints in one transfer, rather than extending
        # the list with one 0-d device tensor per block id.
        kv_indices.extend(block_tables[seq_idx, :num_blocks].tolist())
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        # A sequence that exactly fills its last page reports a full page.
        kv_last_page_lens.append(seq_len % block_size or block_size)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    # Baseline decode. The same workspace buffer is reused by the TRT-LLM call.
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_query_heads // num_kv_heads) > 4),
    )
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 q_data_type=dtype,
                 kv_data_type=dtype,
                 logits_soft_cap=soft_cap)
    output = wrapper.run(query, key_value_cache, scale)

    # TRTLLM Decode
    # NOTE(review): soft_cap is forwarded only to the baseline plan() above;
    # the TRT-LLM call receives no soft-cap argument -- confirm this is
    # intended.
    kv_lens_tensor = torch.tensor(kv_lens,
                                  dtype=torch.int,
                                  device=query.device)
    output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query.contiguous(),
        key_value_cache,
        workspace_buffer,
        num_query_heads,
        num_kv_heads,
        scale,
        block_tables,
        kv_lens_tensor,
        block_size,
        max_kv_len,
        "auto",
        k_scale,
        v_scale,
    )

    # Pass the diagnostic through msg= so it is actually reported on failure;
    # the callable form keeps the default message and only computes the
    # difference when the comparison fails.
    torch.testing.assert_close(
        output,
        output_trtllm,
        atol=1e-2,
        rtol=1e-2,
        msg=lambda default_msg: (
            f"{default_msg}\n"
            f"max abs diff: {torch.max(torch.abs(output - output_trtllm))}"),
    )
...@@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# change the attention backend to triton MLA # change the attention backend to triton MLA
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, backend = get_attn_backend(576,
False, True) torch.bfloat16,
"auto",
16,
False,
use_mla=True)
assert (backend.get_name() == "TRITON_MLA" assert (backend.get_name() == "TRITON_MLA"
or backend.get_name() == "TRITON_MLA_VLLM_V1") or backend.get_name() == "TRITON_MLA_VLLM_V1")
...@@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# If use_mla is true # If use_mla is true
# The selected backend is triton MLA # The selected backend is triton MLA
m.setenv(STR_BACKEND_ENV_VAR, None) m.setenv(STR_BACKEND_ENV_VAR, None)
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, backend = get_attn_backend(576,
False, True) torch.bfloat16,
"auto",
16,
False,
use_mla=True)
assert (backend.get_name() == "TRITON_MLA" assert (backend.get_name() == "TRITON_MLA"
or backend.get_name() == "TRITON_MLA_VLLM_V1") or backend.get_name() == "TRITON_MLA_VLLM_V1")
# # change the attention backend to AITER MLA # change the attention backend to AITER MLA
# m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") # m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, # backend = get_attn_backend(576,
# False, True) # torch.bfloat16,
# "auto",
# 1,
# False,
# use_mla=True)
# assert (backend.get_name() == "ROCM_AITER_MLA" # assert (backend.get_name() == "ROCM_AITER_MLA"
# or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") # or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
...@@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# # The selected backend is ROCM_AITER_MLA # # The selected backend is ROCM_AITER_MLA
# m.setenv(STR_BACKEND_ENV_VAR, None) # m.setenv(STR_BACKEND_ENV_VAR, None)
# m.setenv("VLLM_ROCM_USE_AITER", "1") # m.setenv("VLLM_ROCM_USE_AITER", "1")
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, # backend = get_attn_backend(576,
# False, True) # torch.bfloat16,
# "auto",
# 1,
# False,
# use_mla=True)
# assert (backend.get_name() == "ROCM_AITER_MLA" # assert (backend.get_name() == "ROCM_AITER_MLA"
# or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") # or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
\ No newline at end of file
...@@ -77,6 +77,7 @@ def ref_paged_attn( ...@@ -77,6 +77,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_decode_with_paged_kv( def test_flashinfer_decode_with_paged_kv(
kv_lens: list[int], kv_lens: list[int],
...@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
sliding_window: Optional[int],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
...@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores=( use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4) (num_query_heads//num_kv_heads) > 4)
) )
wrapper.plan(kv_indptr, wrapper.plan(
kv_indices, kv_indptr,
kv_last_page_lens, kv_indices,
num_query_heads, kv_last_page_lens,
num_kv_heads, num_query_heads,
head_size, num_kv_heads,
block_size, head_size,
"NONE", block_size,
q_data_type=dtype, "NONE",
kv_data_type=dtype, window_left=sliding_window - 1 if sliding_window is not None else -1,
logits_soft_cap=soft_cap) q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap,
)
output = wrapper.run(query, key_value_cache) output = wrapper.run(query, key_value_cache)
...@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
kv_lens=kv_lens, kv_lens=kv_lens,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], def test_flashinfer_prefill_with_paged_kv(
num_heads: tuple[int, int], seq_lens: list[tuple[int, int]],
head_size: int, dtype: torch.dtype, num_heads: tuple[int, int],
block_size: int, head_size: int,
soft_cap: Optional[float]) -> None: dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
sliding_window: Optional[int],
) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
...@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], ...@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
window_left=sliding_window - 1 if sliding_window is not None else -1,
q_data_type=dtype, q_data_type=dtype,
kv_data_type=dtype, kv_data_type=dtype,
logits_soft_cap=soft_cap, logits_soft_cap=soft_cap,
...@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], ...@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
kv_lens=kv_lens, kv_lens=kv_lens,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
......
...@@ -26,6 +26,7 @@ CUDA_DEVICES = [ ...@@ -26,6 +26,7 @@ CUDA_DEVICES = [
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
@torch.inference_mode() @torch.inference_mode()
def test_rms_norm( def test_rms_norm(
num_tokens: int, num_tokens: int,
...@@ -34,13 +35,17 @@ def test_rms_norm( ...@@ -34,13 +35,17 @@ def test_rms_norm(
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: str, device: str,
strided_input: bool,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype) layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1) layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size) scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype) last_dim = 2 * hidden_size if strided_input else hidden_size
x = torch.randn(num_tokens, last_dim, dtype=dtype)
x = x[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale x *= scale
residual = torch.randn_like(x) * scale if add_residual else None residual = torch.randn_like(x) * scale if add_residual else None
...@@ -63,7 +68,7 @@ def test_rms_norm( ...@@ -63,7 +68,7 @@ def test_rms_norm(
else: else:
opcheck(torch.ops._C.rms_norm, opcheck(torch.ops._C.rms_norm,
(out, x, layer.weight.data, layer.variance_epsilon)) (out, x, layer.weight.data, layer.variance_epsilon))
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS) # @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) # @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
...@@ -72,6 +77,7 @@ def test_rms_norm( ...@@ -72,6 +77,7 @@ def test_rms_norm(
# @pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0]) # @pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
# @pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("strided_input", [False, True])
# def test_fused_rms_norm_quant( # def test_fused_rms_norm_quant(
# num_tokens: int, # num_tokens: int,
# hidden_size: int, # hidden_size: int,
...@@ -80,13 +86,18 @@ def test_rms_norm( ...@@ -80,13 +86,18 @@ def test_rms_norm(
# quant_scale: float, # quant_scale: float,
# seed: int, # seed: int,
# device: str, # device: str,
# strided_input: bool,
# ) -> None: # ) -> None:
# current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
# torch.set_default_device(device) # torch.set_default_device(device)
# weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1) # weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
# scale = 1 / (2 * hidden_size) # scale = 1 / (2 * hidden_size)
# x = torch.randn(num_tokens, hidden_size, dtype=dtype) # last_dim = 2 * hidden_size if strided_input else hidden_size
# x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
# x = x_base[..., :hidden_size]
# assert x.is_contiguous() != strided_input
# x *= scale # x *= scale
# if add_residual: # if add_residual:
# residual = torch.randn_like(x) * scale # residual = torch.randn_like(x) * scale
...@@ -106,9 +117,11 @@ def test_rms_norm( ...@@ -106,9 +117,11 @@ def test_rms_norm(
# # Unfused kernel is in-place so it goes second # # Unfused kernel is in-place so it goes second
# # Also use a separate clone of x to avoid modifying the input # # Also use a separate clone of x to avoid modifying the input
# x_unfused = x.clone() # x_unfused_base = x_base.clone()
# x_unfused = x_unfused_base[..., :hidden_size]
# assert x_unfused.is_contiguous() != strided_input
# torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6) # torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
# torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused, # torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
# quant_scale_t) # quant_scale_t)
# torch.cuda.synchronize() # torch.cuda.synchronize()
...@@ -116,7 +129,6 @@ def test_rms_norm( ...@@ -116,7 +129,6 @@ def test_rms_norm(
# residual, # residual,
# atol=1e-2, # atol=1e-2,
# rtol=1e-2) # rtol=1e-2)
# opcheck( # opcheck(
# torch.ops._C.fused_add_rms_norm_static_fp8_quant, # torch.ops._C.fused_add_rms_norm_static_fp8_quant,
# (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) # (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
...@@ -131,7 +143,7 @@ def test_rms_norm( ...@@ -131,7 +143,7 @@ def test_rms_norm(
# opcheck(torch.ops._C.rms_norm_static_fp8_quant, # opcheck(torch.ops._C.rms_norm_static_fp8_quant,
# (out_quant_fused, x, weight, quant_scale_t, 1e-6)) # (out_quant_fused, x, weight, quant_scale_t, 1e-6))
# torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32), # torch.testing.assert_close(out_quant.to(dtype=torch.float32),
# out_quant.to(dtype=torch.float32), # out_quant_fused.to(dtype=torch.float32),
# atol=1e-3, # atol=1e-3,
# rtol=1e-3) # rtol=1e-3)
...@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel( ...@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
gate_states[..., local_rank * N:(local_rank + 1) * N], gate_states[..., local_rank * N:(local_rank + 1) * N],
) )
ref_output = mixer_single_gpu(hidden_states, gate_states) ref_output = mixer_single_gpu(hidden_states, gate_states)
torch.allclose(output, torch.testing.assert_close(output,
ref_output[..., local_rank * N:(local_rank + 1) * N], ref_output[...,
atol=1e-3, local_rank * N:(local_rank + 1) * N],
rtol=1e-3) atol=5e-3,
rtol=1e-3)
...@@ -6,11 +6,11 @@ import torch ...@@ -6,11 +6,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined) mamba_chunk_scan_combined)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
_query_start_loc_to_chunk_indices_offsets)
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024
...@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, ...@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# this tests the kernels on a single example (no batching) # this tests the kernels on a single example (no batching)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
if itype == torch.bfloat16:
atol, rtol = 5e-2, 5e-2
else:
atol, rtol = 8e-3, 5e-3
# set seed # set seed
batch_size = 1 # batch_size batch_size = 1 # batch_size
# ssd_minimal_discrete requires chunk_size divide seqlen # ssd_minimal_discrete requires chunk_size divide seqlen
...@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, ...@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
return_final_states=True) return_final_states=True)
# just test the last in sequence # just test the last in sequence
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3) torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
# just test the last head # just test the last head
# NOTE, in the kernel we always cast states to fp32 # NOTE, in the kernel we always cast states to fp32
torch.allclose(final_state[:, -1], torch.testing.assert_close(final_state[:, -1],
final_state_min[:, -1].to(torch.float32), final_state_min[:, -1].to(torch.float32),
atol=1e-3, atol=atol,
rtol=1e-3) rtol=rtol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
...@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# TODO: the irregular chunk size cases have some issues and require higher
# tolerance. This is to be investigated
if chunk_size not in {8, 256}:
atol, rtol = 5e-1, 5e-1
else:
atol, rtol = 5e-3, 5e-3
# hold state during the cutting process so we know if an # hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle # example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample last_taken: dict = {} # map: eg -> pointer to last taken sample
...@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# just test one dim and dstate # just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0] Y_min_eg = Y_min[i][:, 0, 0]
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
# update states # update states
states = new_states states = new_states
......
...@@ -6,9 +6,8 @@ from typing import Optional ...@@ -6,9 +6,8 @@ from typing import Optional
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
...@@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, ...@@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
x = x.contiguous() x = x.contiguous()
bias = bias.contiguous() if bias is not None else None bias = bias.contiguous() if bias is not None else None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
has_initial_state, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
current_platform.seed_everything(0)
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
if has_initial_state:
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
has_initial_state_tensor = torch.ones(batch,
dtype=torch.bool,
device=x.device)
else:
initial_states = None
has_initial_state_tensor = None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
initial_states_ref = initial_states.clone(
) if initial_states is not None else None
activation = None if not silu_activation else "silu"
out = causal_conv1d_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=has_initial_state_tensor)
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=True,
activation=activation)
if has_initial_state:
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=has_initial_state_tensor)
@pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("silu_activation", [False, True])
...@@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ...@@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref) assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, None, PAD_SLOT_ID))
@pytest.mark.parametrize("itype", @pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5]) @pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) @pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded # tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False]) @pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, @pytest.mark.parametrize("batch_size", [3])
seqlen, has_bias, def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
width, seqlen, has_bias,
silu_activation, itype): silu_activation, itype):
device = "cuda" device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
...@@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, ...@@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
# set seed # set seed
current_platform.seed_everything(0) current_platform.seed_everything(0)
batch_size = 3
padding = 5 if with_padding else 0 padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding padded_batch_size = batch_size + padding
# total_entries = number of cache lines
total_entries = 10 * batch_size total_entries = 10 * batch_size
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) # x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
dtype=itype).transpose(1, 2)
x_ref = x.clone() x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to( conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
...@@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, ...@@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
], ],
dim=0) dim=0)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries, conv_state = torch.randn(total_entries,
dim,
width - 1, width - 1,
dim,
device=device, device=device,
dtype=itype) dtype=itype).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone() conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim, width, device=device, dtype=itype) weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone() conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu" activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, out = causal_conv1d_update(x,
conv_state, conv_state,
weight, weight,
...@@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, ...@@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
activation=activation) activation=activation)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
assert torch.equal(conv_state[unused_states_bool], assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool]) conv_state_for_padding_test[unused_states_bool])
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
@pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize( @pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096]) @pytest.mark.parametrize('dim', [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False]) @pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, @pytest.mark.parametrize('batch', [4, 10])
silu_activation, itype): def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
has_bias, silu_activation, itype):
device = "cuda" device = "cuda"
torch.cuda.empty_cache() torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
...@@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, ...@@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
# set seed # set seed
current_platform.seed_everything(0) current_platform.seed_everything(0)
seqlens = [] seqlens = []
batch_size = 4 batch_size = batch
if seqlen < 10:
batch_size = 1
padding = 3 if with_padding else 0 padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1 nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append( seqlens.append(
torch.diff( torch.diff(
torch.cat( torch.cat(
...@@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, ...@@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0) dim=0)
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, x = rearrange(
dtype=itype)[:, 4096:4096 + dim, :] torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
"b s d -> b d s")[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype) weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone() x_ref = x.clone()
weight_ref = weight.clone() weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu" activation = None if not silu_activation else "silu"
final_states = torch.randn(total_entries, final_states = torch.randn(total_entries,
dim,
width - 1, width - 1,
dim,
device=x.device, device=x.device,
dtype=x.dtype) dtype=x.dtype).transpose(1, 2)
final_states_ref = final_states.clone() final_states_ref = final_states.clone()
has_initial_states = torch.randint(0, has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ), 2, (cumsum.shape[0] - 1, ),
...@@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, ...@@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
], ],
dim=-1) dim=-1)
out = causal_conv1d_fn(x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID)
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,
final_states, activation, PAD_SLOT_ID)
out_ref = [] out_ref = []
out_ref_b = [] out_ref_b = []
...@@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, ...@@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0) out_ref_tensor = torch.cat(out_ref, dim=0)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
assert torch.allclose(final_states[state_indices], assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices], final_states_ref[state_indices],
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
padded_state_indices, has_initial_states,
final_states, activation)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from .common import Config
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
def make_config_arg_parser(description: str):
    """Build the command-line parser describing a modular-kernel MoE run.

    The parser covers the problem shape (``-m``, ``-k``, ``-n``,
    ``--num-experts``, ``--topk``), the PrepareFinalize / FusedExperts
    classes to combine, the quantization options and an optional torch
    trace output directory.

    Args:
        description: Text passed to ``argparse.ArgumentParser``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    pf_type_names = [x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]
    experts_type_names = [x.__name__ for x in MK_FUSED_EXPERT_TYPES]

    # The converters below are handed to argparse as ``type=`` callables;
    # each maps a class / dtype name given on the command line to the
    # corresponding object, or raises ValueError when nothing matches.
    def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
        found = next(
            (c for c in MK_ALL_PREPARE_FINALIZE_TYPES if c.__name__ == s),
            None)
        if found is None:
            raise ValueError(
                f"Cannot find a PrepareFinalize type that matches {s}")
        return found

    def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
        found = next((c for c in MK_FUSED_EXPERT_TYPES if c.__name__ == s),
                     None)
        if found is None:
            raise ValueError(
                f"Cannot find a FusedExperts type that matches {s}")
        return found

    def to_quant_torch_dtype(s: str) -> torch.dtype:
        if s != "torch.float8_e4m3fn":
            raise ValueError(f"Unsupported quant type {s}")
        return torch.float8_e4m3fn

    parser = argparse.ArgumentParser(description=description)

    # Parallelism and implementation selection
    parser.add_argument("--world-size",
                        type=int,
                        default=2,
                        help="Number of ranks that participate in all2all")
    parser.add_argument("--pf-type",
                        type=to_pf_class_type,
                        required=True,
                        help=("Choose a PrepareFinalize Type : "
                              f"{pf_type_names}"))
    parser.add_argument("--experts-type",
                        type=to_experts_class_type,
                        required=True,
                        help=("Choose a FusedExpert type : "
                              f"{experts_type_names}"))

    # Problem shape
    parser.add_argument("-m",
                        nargs="+",
                        type=int,
                        default=[64],
                        help="num tokens per rank")
    parser.add_argument("-k", type=int, default=7168, help="hidden-size")
    parser.add_argument("-n",
                        type=int,
                        default=1024,
                        help="N dimension of the first fused-moe matmul")
    parser.add_argument("--num-experts",
                        type=int,
                        default=32,
                        help="Global num experts")
    parser.add_argument("--topk",
                        nargs="+",
                        type=int,
                        default=[4, 1],
                        help="num topk")
    parser.add_argument(
        "--fused-moe-chunk-size",
        type=int,
        help="Fused moe chunk size used for the non-batched fused experts impl."
    )

    # Quant args
    parser.add_argument("--quant-dtype",
                        type=to_quant_torch_dtype,
                        help="Quant datatype")
    parser.add_argument("--per-token-quantized-activations",
                        action="store_true",
                        help=("The input activations must be per-token "
                              "quantized"))
    parser.add_argument("--per-channel-quantized-weights",
                        action="store_true",
                        help="The weights must be per-channel quantized.")
    parser.add_argument("--block-shape",
                        nargs="+",
                        type=int,
                        help="Quantization block shape")

    # Torch trace profile generation args
    parser.add_argument("--torch-trace-dir-path",
                        type=str,
                        default=None,
                        help="Get torch trace for single execution")

    return parser
def _validate_args(args: argparse.Namespace):
    """Sanity-check parsed CLI arguments before a Config is built.

    Checks performed:
      * a quant dtype, when given, must be ``torch.float8_e4m3fn`` and any
        block shape supplied with it must have exactly two elements;
      * a single-GPU PrepareFinalize type requires ``--world-size 1``;
      * the torch trace directory, when given, must already exist.

    Args:
        args: Namespace produced by the parser from
            ``make_config_arg_parser``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    if args.quant_dtype is not None:
        assert args.quant_dtype == torch.float8_e4m3fn
        if args.block_shape is not None:
            assert len(args.block_shape) == 2, (
                f"block shape must have 2 elements. got {args.block_shape}")

    # The single-GPU list holds PrepareFinalize classes, so it has to be
    # compared against the PrepareFinalize selection (pf_type). Comparing the
    # experts type against it could never match and silently skipped the
    # world-size check.
    if args.pf_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
        assert args.world_size == 1, (
            "Single GPU objects need world size set to 1")

    if args.torch_trace_dir_path is not None:
        from pathlib import Path
        assert Path(args.torch_trace_dir_path).is_dir(), (
            f"Please create {args.torch_trace_dir_path}")
def make_config(args: argparse.Namespace) -> Config:
    """Validate parsed CLI arguments and translate them into a ``Config``.

    Args:
        args: Namespace produced by the parser from
            ``make_config_arg_parser``.

    Returns:
        A ``Config`` populated from ``args``. The activation dtype is fixed
        to ``torch.bfloat16`` and a quant config is attached only when
        ``--quant-dtype`` was supplied.
    """
    _validate_args(args)

    if args.quant_dtype is None:
        quant_config = None
    else:
        quant_config = FusedMoEQuantConfig(
            quant_dtype=args.quant_dtype,
            per_act_token_quant=args.per_token_quantized_activations,
            per_out_ch_quant=args.per_channel_quantized_weights,
            block_shape=args.block_shape,
        )

    return Config(
        Ms=args.m,
        K=args.k,
        N=args.n,
        E=args.num_experts,
        topks=args.topk,
        dtype=torch.bfloat16,  # hard-code
        quant_config=quant_config,
        prepare_finalize_type=args.pf_type,
        fused_experts_type=args.experts_type,
        fused_moe_chunk_size=args.fused_moe_chunk_size,
        world_size=args.world_size,
        torch_trace_dir_path=args.torch_trace_dir_path,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .parallel_utils import ProcessGroupInfo
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
make_quant_fp8_weights, per_token_cast_to_fp8)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
if t is None:
return f"{name} : None"
else:
return f"{name} : {t.shape} {t.dtype} {t.device}"
@dataclass
class Config:
    """One modular-kernel MoE configuration (shape, impls, quantization).

    ``Ms`` and ``topks`` may each be a list (a sweep) or a single int. The
    ``M`` and ``topk`` properties assert that the value has already been
    narrowed to an int.
    """
    # Token counts (list for a sweep, int for a single run).
    Ms: Union[list[int], int]
    K: int
    N: int
    # Global number of experts.
    E: int
    topks: Union[list[int], int]
    dtype: torch.dtype
    quant_config: Optional[FusedMoEQuantConfig]

    # The PrepareFinalize / FusedExperts *classes* to combine.
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute

    fused_moe_chunk_size: Optional[int]
    world_size: int

    torch_trace_dir_path: Optional[str] = None

    def _pf_type_is(self, *type_names: str) -> bool:
        """Return True if the PrepareFinalize class has one of the names.

        The pplx / DeepEP PrepareFinalize classes are imported only when the
        corresponding package is installed, so referring to them directly
        raises NameError on machines without those packages. Comparing class
        names keeps every query usable regardless of what is installed.
        """
        return self.prepare_finalize_type.__name__ in type_names

    def describe(self) -> str:
        """Return a multi-line, human-readable summary of this config."""
        lines = [
            "== Config: \n",
            f" world_size={self.world_size} \n",
            f" PF={self.prepare_finalize_type.__name__} \n",
            f" FE={self.fused_experts_type.__name__} \n",
            f" topk={self.topks} \n",
            f" dtype={self.dtype} \n",
            f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n",
            " Quant: \n",
        ]
        # fused_moe_chunk_size is reported once above; it is not a
        # quantization setting and is not repeated in the Quant section.
        if self.quant_config is not None:
            lines += [
                f" q_dtype={self.quant_dtype} \n",
                f" q_block_shape={self.quant_block_shape} \n",
                f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n",
                f" q_per_act_token={self.is_per_act_token_quant} \n",
            ]
        else:
            lines.append(" quant=None \n")
        return "".join(lines)

    @property
    def M(self) -> int:
        """The single token count; asserts ``Ms`` is an int."""
        assert isinstance(self.Ms, int)
        return self.Ms

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        """Quantized dtype, or None when no quant config is set."""
        if self.quant_config is None:
            return None
        return self.quant_config.quant_dtype

    @property
    def is_per_act_token_quant(self) -> bool:
        """Whether the quant config requests per-token activation quant."""
        if self.quant_config is None:
            return False
        return self.quant_config.per_act_token_quant

    @property
    def is_per_tensor_act_quant(self) -> bool:
        """True when quantized with neither per-token nor block scaling."""
        if self.quant_config is None:
            return False
        return (not self.is_per_act_token_quant
                and self.quant_block_shape is None)

    @property
    def is_per_out_ch_quant(self) -> bool:
        """Whether the quant config requests per-output-channel weights."""
        if self.quant_config is None:
            return False
        return self.quant_config.per_out_ch_quant

    @property
    def quant_block_shape(self) -> Optional[list[int]]:
        """Quantization block shape, or None when not block-quantized."""
        if self.quant_config is None:
            return None
        return self.quant_config.block_shape

    @property
    def topk(self) -> int:
        """The single top-k value; asserts ``topks`` is an int."""
        assert isinstance(self.topks, int)
        return self.topks

    @property
    def topk_ids_dtype(self) -> Optional[torch.dtype]:
        """dtype used for topk ids with this PrepareFinalize type.

        ``torch.uint32`` for pplx, ``torch.int64`` for the DeepEP variants
        and None otherwise.
        """
        if self._pf_type_is("PplxPrepareAndFinalize"):
            return torch.uint32
        if self._pf_type_is("DeepEPHTPrepareAndFinalize",
                            "DeepEPLLPrepareAndFinalize"):
            return torch.int64
        return None

    @property
    def num_local_experts(self) -> int:
        """Experts per rank (``E // world_size``)."""
        return self.E // self.world_size

    def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
        """
        make env data for vllm launch.

        Returns a ``VllmConfig`` with data-parallel size set to
        ``world_size`` and expert parallelism enabled, plus the environment
        variables selecting the all2all backend, DeepGEMM usage and
        (optionally) the fused-moe chunk size.
        """
        vllm_config = VllmConfig()
        vllm_config.parallel_config.data_parallel_size = self.world_size
        vllm_config.parallel_config.enable_expert_parallel = True

        env_dict = {
            "VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
            "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
        }
        if self.fused_moe_chunk_size is not None:
            env_dict.update(
                {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
        return vllm_config, env_dict

    def is_fp8_block_quantized(self):
        """True for fp8 quantization with a block shape."""
        return (self.quant_dtype == torch.float8_e4m3fn
                and self.quant_block_shape is not None)

    def is_batched_prepare_finalize(self):
        """True for the pplx and DeepEP low-latency PrepareFinalize types."""
        return self._pf_type_is("PplxPrepareAndFinalize",
                                "DeepEPLLPrepareAndFinalize")

    def is_batched_fused_experts(self):
        """True if the experts type is paired with batched PrepareFinalize."""
        return self.fused_experts_type in [
            CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
            NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
        ]

    def is_standard_fused_experts(self):
        """True if the experts type is paired with non-batched
        PrepareFinalize."""
        return self.fused_experts_type in [
            CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
            TritonExperts
        ]

    def is_fe_16bit_supported(self):
        """True if the experts type is listed as supporting unquantized
        16-bit inputs."""
        return self.fused_experts_type in [
            BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
            NaiveBatchedExperts, TritonExperts
        ]

    def is_fe_fp8_supported(self):
        """True if the experts type is listed as supporting fp8."""
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            BatchedTritonOrDeepGemmExperts,
            CutlassExpertsFp8,
            DeepGemmExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
            NaiveBatchedExperts,
        ]

    def is_fe_block_fp8_supported(self):
        """True if the experts type is listed as supporting block fp8."""
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            BatchedTritonOrDeepGemmExperts,
            DeepGemmExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
            BatchedTritonExperts,
            NaiveBatchedExperts,
        ]

    def is_fe_supports_chunking(self):
        """True if the experts type is listed as supporting chunking."""
        return self.fused_experts_type in [
            CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
            TritonExperts
        ]

    def needs_deep_gemm(self):
        """True for the experts types that always use DeepGEMM."""
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            DeepGemmExperts,
        ]

    def needs_pplx(self):
        """True if the PrepareFinalize type is the pplx one."""
        return self._pf_type_is("PplxPrepareAndFinalize")

    def needs_deep_ep(self):
        """True if the PrepareFinalize type is one of the DeepEP ones."""
        return self._pf_type_is("DeepEPHTPrepareAndFinalize",
                                "DeepEPLLPrepareAndFinalize")

    def all2all_backend(self):
        """Name of the all2all backend matching the PrepareFinalize type."""
        if self.needs_pplx():
            return "pplx"
        if self._pf_type_is("DeepEPHTPrepareAndFinalize"):
            return "deepep_high_throughput"
        if self._pf_type_is("DeepEPLLPrepareAndFinalize"):
            return "deepep_low_latency"
        return "naive"

    def needs_all2all(self):
        """True if the PrepareFinalize type is pplx or DeepEP."""
        return self._pf_type_is("PplxPrepareAndFinalize",
                                "DeepEPHTPrepareAndFinalize",
                                "DeepEPLLPrepareAndFinalize")

    def is_valid(self):
        """Return True if this combination of settings can be run."""
        # Check prepare-finalize and fused-experts compatibility
        if self.is_batched_prepare_finalize():
            if not self.is_batched_fused_experts():
                return False
        else:
            if not self.is_standard_fused_experts():
                return False

        use_chunking = self.fused_moe_chunk_size is not None
        if use_chunking and not self.is_fe_supports_chunking():
            return False

        # Check quantization sanity
        if (int(self.is_per_act_token_quant) +
                int(self.is_per_tensor_act_quant) +
                int(self.quant_block_shape is not None)) > 1:
            # invalid quant config
            return False

        # check bf16 / fp16 support
        is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
        if is_16bit and not self.is_fe_16bit_supported():
            return False

        # Check fp8 support
        is_fp8 = self.quant_dtype == torch.float8_e4m3fn
        if is_fp8 and not self.is_fe_fp8_supported():
            return False

        # Check fp8 block quantization support
        is_block_quantized = self.quant_block_shape is not None
        if is_block_quantized and not is_fp8:
            return False
        if is_block_quantized and not self.is_fe_block_fp8_supported():
            return False

        # deep_gemm only works with block-quantized
        if self.needs_deep_gemm() and not is_block_quantized:
            return False

        # Check dependencies
        if self.needs_deep_ep() and not has_deep_ep():
            return False
        if self.needs_deep_gemm() and not has_deep_gemm():
            return False
        if self.needs_pplx() and not has_pplx():  # noqa: SIM103
            return False

        return True
@dataclass
class WeightTensors:
    """The two expert weight tensors and their optional quant scales."""
    w1: torch.Tensor
    w2: torch.Tensor
    w1_scale: Optional[torch.Tensor]
    w2_scale: Optional[torch.Tensor]

    def describe(self):
        """Return a multi-line summary of every tensor held."""
        parts = ["== Weight Tensors: \n"]
        for field_name in ("w1", "w2", "w1_scale", "w2_scale"):
            tensor = getattr(self, field_name)
            parts.append(f' - {_describe_tensor(tensor, field_name)} \n')
        return "".join(parts)

    def to_current_device(self):
        """Move the weights (and scales, for fp8 weights) to the current
        CUDA device, in place."""
        device = torch.cuda.current_device()
        self.w1 = self.w1.to(device=device)
        self.w2 = self.w2.to(device=device)

        # Scales are only expected to exist for fp8 weights.
        if self.w1.dtype == torch.float8_e4m3fn:
            assert self.w1_scale is not None
            assert self.w2_scale is not None
            self.w1_scale = self.w1_scale.to(device=device)
            self.w2_scale = self.w2_scale.to(device=device)

    def slice_weights(self, rank: int,
                      num_local_experts: int) -> "WeightTensors":
        """Return the slice of experts that belongs to ``rank``.

        Experts ``[rank * num_local_experts, (rank + 1) * num_local_experts)``
        are selected along the first dimension.
        """
        first = rank * num_local_experts
        last = first + num_local_experts

        local_w1_scale, local_w2_scale = None, None
        if self.w1.dtype == torch.float8_e4m3fn:
            assert self.w1_scale is not None
            assert self.w2_scale is not None
            local_w1_scale = self.w1_scale[first:last, :, :]
            local_w2_scale = self.w2_scale[first:last, :, :]

        return WeightTensors(self.w1[first:last, :, :],
                             self.w2[first:last, :, :], local_w1_scale,
                             local_w2_scale)

    @staticmethod
    def make(config: Config) -> "WeightTensors":
        """Create weights matching ``config``: unquantized, fp8, or
        block-quantized fp8."""
        if config.quant_dtype is None:
            # just make normal dtype weights
            w1, w2 = make_non_quant_weights(e=config.E,
                                            n=config.N,
                                            k=config.K,
                                            dtype=config.dtype)
            return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)

        assert config.quant_dtype == torch.float8_e4m3fn
        if config.is_fp8_block_quantized():
            assert config.quant_block_shape is not None
            w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
                e=config.E,
                n=config.N,
                k=config.K,
                block_size=config.quant_block_shape,
            )
        else:
            w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
                e=config.E,
                n=config.N,
                k=config.K,
                per_out_channel_quant=config.is_per_out_ch_quant,
            )
        return WeightTensors(w1=w1,
                             w2=w2,
                             w1_scale=w1_scale,
                             w2_scale=w2_scale)
@dataclass
class RankTensors:
    """Per-rank inputs for one MoE invocation."""
    hidden_states: torch.Tensor
    hidden_states_scale: Optional[torch.Tensor]

    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    expert_map: Optional[torch.Tensor]

    quant_config: Optional[FusedMoEQuantConfig]

    def describe(self):
        """Return a multi-line summary of every tensor held."""
        labelled = (
            (self.hidden_states, "HS"),
            (self.hidden_states_scale, "HS_scale"),
            (self.topk_weights, "topk_weights"),
            (self.topk_ids, "topk_ids"),
            (self.expert_map, "expert_map"),
        )
        parts = ["== Rank Tensors: \n"]
        parts.extend(f' - {_describe_tensor(tensor, label)} \n'
                     for tensor, label in labelled)
        return "".join(parts)

    @staticmethod
    def make_hidden_states(
            config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Return hidden_states

        The second element is the activation scale; it is only non-None in
        the per-tensor quantized case.
        """
        m, k, dtype = config.M, config.K, config.dtype
        a = torch.randn(
            (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0

        if config.quant_dtype is None:
            return a, None

        # We dequant and use that as hidden_states so the tests are stable.
        # quantizing and dequantizing yield slightly different results
        # depending on the hardware. Here we, quantize and dequantize
        # first - so further quantize and dequantize will yield the same
        # values.
        if config.is_per_tensor_act_quant:
            a_q, a_scales = ops.scaled_fp8_quant(
                a, use_per_token_if_dynamic=False)
            dequantized = a_q.float().mul(a_scales).to(dtype)
            return dequantized, a_scales

        if config.is_per_act_token_quant:
            a_q, a_scales = ops.scaled_fp8_quant(
                a, use_per_token_if_dynamic=True)
            dequantized = a_q.float().mul(a_scales).to(dtype)
            return dequantized, None

        # Block-quantized activations.
        assert config.quant_block_shape is not None
        block_k = config.quant_block_shape[1]
        a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
        blocks = a_q.float().view((-1, block_k))
        dequantized = blocks.mul(a_scales.view(-1, 1)).view(m, k).to(dtype)
        return dequantized, None

    @staticmethod
    def make(config: Config, pgi: ProcessGroupInfo):
        """Create the hidden states, routing tensors and expert map for the
        rank described by ``pgi``."""
        dtype = config.dtype
        topk, m, _ = (config.topk, config.M, config.K)
        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
            config)

        num_local_experts = config.num_local_experts
        global_num_experts = config.E

        score = torch.randn((m, global_num_experts),
                            device="cuda",
                            dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
                                               False)
        topk_ids = topk_ids.to(config.topk_ids_dtype)

        # distribute topk_ids evenly
        for row in range(m):
            topk_ids[row] = torch.randperm(config.E)[:topk]
        topk_ids = topk_ids.to(device=torch.cuda.current_device())

        expert_map = None
        if config.world_size > 1:
            # -1 everywhere except for this rank's experts, which are
            # numbered 0..num_local_experts-1.
            expert_map = torch.full((global_num_experts, ),
                                    fill_value=-1,
                                    dtype=torch.int32)
            first = pgi.rank * num_local_experts
            last = first + num_local_experts
            expert_map[first:last] = torch.tensor(
                list(range(num_local_experts)))
            expert_map = expert_map.to(device=torch.cuda.current_device(),
                                       dtype=torch.int32)

        return RankTensors(
            hidden_states=hidden_states,
            hidden_states_scale=hidden_states_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            quant_config=config.quant_config,
        )
def reference_moe_impl(config: Config, weights: WeightTensors,
                       rank_tensors: RankTensors) -> torch.Tensor:
    """Compute the reference MoE output with ``torch_experts``.

    The full (unsliced) ``weights`` are used and no expert map is passed.
    Router weights are applied on the input when ``config.topk == 1``.
    """
    router_weights_on_input = config.topk == 1
    reference_out = torch_experts(
        a=rank_tensors.hidden_states,
        w1=weights.w1,
        w2=weights.w2,
        topk_weight=rank_tensors.topk_weights,
        topk_ids=rank_tensors.topk_ids,
        global_num_experts=config.E,
        expert_map=None,
        w1_scale=weights.w1_scale,
        w2_scale=weights.w2_scale,
        a1_scale=rank_tensors.hidden_states_scale,
        quant_dtype=config.quant_dtype,
        per_act_token_quant=config.is_per_act_token_quant,
        block_shape=config.quant_block_shape,
        apply_router_weights_on_input=router_weights_on_input,
    )
    return reference_out
def make_fused_experts(
        config: Config, moe: FusedMoEConfig,
        num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Instantiate the fused-experts implementation chosen in ``config``.

    Args:
        config: Test configuration; ``fused_experts_type`` selects the class
            and the quantization properties feed its constructor.
        moe: MoE config supplying ``max_num_tokens``, expert counts and the
            input dtype.
        num_dispatchers: Forwarded to the implementations that take it.

    Returns:
        The constructed fused-experts object.

    Raises:
        ValueError: if ``config.fused_experts_type`` is not one of the
            handled classes (previously this surfaced as an
            UnboundLocalError on the return statement).
    """
    use_fp8 = config.quant_dtype == torch.float8_e4m3fn
    batch_kwargs = {
        "max_num_tokens": moe.max_num_tokens,
        "num_dispatchers": num_dispatchers,
    }
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8,
        "use_int8_w8a8": False,
        "use_int8_w8a16": False,
        "use_int4_w4a16": False,
        "block_shape": config.quant_block_shape,
        "per_act_token_quant": config.is_per_act_token_quant,
    }
    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}

    if config.fused_experts_type == BatchedDeepGemmExperts:
        kwargs = batch_kwargs | {
            "block_shape": config.quant_block_shape,
            "per_act_token_quant": config.is_per_act_token_quant,
        }
        print(f"Making BatchedDeepGemmExperts {kwargs} ...")
        experts = BatchedDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == BatchedTritonExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making BatchedTritonExperts {kwargs} ...")
        experts = BatchedTritonExperts(**kwargs)
    elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
        kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
        experts = BatchedTritonOrDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == DeepGemmExperts:
        print("Making DeepGemmExperts () ...")
        experts = DeepGemmExperts()
    elif config.fused_experts_type == TritonExperts:
        kwargs = quant_kwargs
        print(f"Making TritonExperts {kwargs} ...")
        experts = TritonExperts(**kwargs)
    elif config.fused_experts_type == TritonOrDeepGemmExperts:
        kwargs = quant_kwargs | deepgemm_kwargs
        print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
        experts = TritonOrDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == NaiveBatchedExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making NaiveBatchedExperts {kwargs} ...")
        experts = NaiveBatchedExperts(**kwargs)
    elif config.fused_experts_type == CutlassExpertsFp8:
        # In batched format the per-worker expert count is the local count;
        # otherwise the global count is used.
        use_batched_format = config.is_batched_prepare_finalize()
        num_experts = (moe.num_local_experts
                       if use_batched_format else moe.num_experts)
        kwargs = {
            "max_experts_per_worker": num_experts,
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": config.is_per_act_token_quant,
            "per_out_ch_quant": config.is_per_out_ch_quant,
            "block_shape": config.quant_block_shape,
            "num_dispatchers": num_dispatchers,
            "use_batched_format": use_batched_format
        }
        print(f"Making CutlassExpertsFp8 {kwargs} ...")
        experts = CutlassExpertsFp8(**kwargs)
    else:
        # Fail with a clear message instead of an unbound local below.
        raise ValueError(
            f"Unsupported fused experts type: {config.fused_experts_type}")

    return experts
def make_modular_kernel(config: Config,
                        vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
    """Assemble a ``FusedMoEModularKernel`` for ``config``.

    A ``FusedMoEConfig`` is derived from ``config`` and the current
    tensor/data parallel groups, the matching PrepareFinalize object is
    created, and it is combined with the fused-experts implementation.
    """

    def next_power_of_2(x):
        import math
        return 1 if x == 0 else 2**math.ceil(math.log2(x))

    # make moe config
    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),
        dp_size_=get_dp_group().world_size,
        vllm_parallel_config=vllm_config.parallel_config,
    )
    moe = FusedMoEConfig(
        num_experts=config.E,
        experts_per_token=config.topk,
        hidden_dim=config.K,
        num_local_experts=config.num_local_experts,
        moe_parallel_config=moe_parallel_config,
        in_dtype=config.dtype,
        quant_config=config.quant_config,
        max_num_tokens=next_power_of_2(config.M),
    )

    # make modular kernel
    if not config.needs_all2all():
        prepare_finalize = MoEPrepareAndFinalizeNoEP()
    else:
        prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
        assert prepare_finalize is not None

    num_dispatchers = prepare_finalize.num_dispatchers()
    fused_experts = make_fused_experts(config, moe, num_dispatchers)

    return mk.FusedMoEModularKernel(prepare_finalize=prepare_finalize,
                                    fused_experts=fused_experts)
def run_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> torch.Tensor:
    """Build the modular kernel for ``config`` and run it on this rank.

    ``config.Ms`` and ``config.topks`` must already be single ints. The
    weights are sliced down to the experts owned by ``pgi.rank`` before the
    kernel is invoked.
    """
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    local_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    kernel = make_modular_kernel(config, vllm_config)

    return kernel.forward(
        # impls might update the tensor in place
        hidden_states=rank_tensors.hidden_states.clone(),
        w1=local_weights.w1,
        w2=local_weights.w2,
        topk_weights=rank_tensors.topk_weights,
        topk_ids=rank_tensors.topk_ids,
        expert_map=rank_tensors.expert_map,
        w1_scale=local_weights.w1_scale,
        w2_scale=local_weights.w2_scale,
        a1_scale=rank_tensors.hidden_states_scale,
        global_num_experts=config.E,
        apply_router_weight_on_input=config.topk == 1,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from enum import Enum
from itertools import product
from typing import Optional
import torch
from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.platforms import current_platform
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
run_modular_kernel)
from .mk_objects import (MK_FUSED_EXPERT_TYPES,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS)
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
class Result(Enum):
    """Outcome recorded for one configuration in the feature matrix."""
    # The configuration ran without raising.
    PASS = 1
    # Running the configuration raised an exception.
    FAIL = 2
    # The configuration was reported invalid and was not run.
    SKIP = 3
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    """Per-rank body: compare the modular kernel against the reference.

    For every ``(m, topk)`` pair in ``config.Ms`` x ``config.topks`` the
    modular-kernel output is checked against ``reference_moe_impl`` with
    ``atol = rtol = 3e-2``.
    """
    current_platform.seed_everything(pgi.rank)

    # sanity check
    from vllm import envs
    if config.fused_moe_chunk_size is not None:
        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # get weights to this device
    weights.to_current_device()

    token_counts = config.Ms
    assert isinstance(token_counts, list)
    topk_values = config.topks
    assert isinstance(topk_values, list)

    for m, topk in product(token_counts, topk_values):
        print(f"Running m={m}, topk={topk} ...")

        # override m and topk
        single_cfg = copy.deepcopy(config)
        single_cfg.Ms = m
        single_cfg.topks = topk

        # inputs for rank
        rank_tensors = RankTensors.make(single_cfg, pgi)

        # modular kernel out
        mk_out = run_modular_kernel(pgi, vllm_config, single_cfg, weights,
                                    rank_tensors)

        with set_current_vllm_config(vllm_config):
            ref_out = reference_moe_impl(single_cfg, weights, rank_tensors)

        torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
def make_feature_matrix(csv_file_path: str):
    """Run every PrepareFinalize x FusedExperts x quant combination and
    write the PASS / FAIL / SKIP outcome of each to ``csv_file_path``.

    Invalid combinations are recorded as SKIP; a combination whose run
    raises any exception is recorded as FAIL.
    """
    from dataclasses import asdict

    import pandas as pd

    def add_to_results(config: Config,
                       success: Result,
                       results_df: Optional[pd.DataFrame] = None):
        """Append one row describing ``config`` and its outcome."""
        row = asdict(config)
        # Store the class names rather than the classes themselves.
        row['prepare_finalize_type'] = row['prepare_finalize_type'].__name__
        row['fused_experts_type'] = row['fused_experts_type'].__name__
        row['per_tensor_act_quant'] = config.is_per_tensor_act_quant

        # Flatten the quant config into the row; an absent quant config is
        # represented by the fields of FusedMoEQuantConfig(None).
        quant_fields = row.pop('quant_config')
        if quant_fields is None:
            quant_fields = asdict(FusedMoEQuantConfig(None))
        row.update(quant_fields)
        row['success'] = success.name

        row_df = pd.DataFrame([row])
        if results_df is None:
            return row_df
        return pd.concat([results_df, row_df], ignore_index=True)

    Ms = [64]
    Ks = [7168]  # hidden sizes
    Ns = [2048]
    TOPKs = [[4, 1]]
    Es = [32]
    DTYPEs = [torch.bfloat16]
    PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    FE_TYPES = MK_FUSED_EXPERT_TYPES
    Q_TYPES = MK_QUANT_CONFIGS

    combinations = list(
        product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES))

    results_df: Optional[pd.DataFrame] = None
    for combo in tqdm(combinations):
        (m, k, n, e, topks, dtype, pf_type, experts_type,
         quant_config) = combo
        config = Config(Ms=[m],
                        K=k,
                        N=n,
                        E=e,
                        topks=topks,
                        dtype=dtype,
                        prepare_finalize_type=pf_type,
                        fused_experts_type=experts_type,
                        quant_config=quant_config,
                        world_size=2,
                        fused_moe_chunk_size=None)

        if not config.is_valid():
            success = Result.SKIP
        else:
            print(f"Running config : {config.describe()} ...")
            try:
                weights: WeightTensors = WeightTensors.make(config)
                vllm_config, env_dict = config.make_env_data()
                parallel_launch_with_config(config.world_size, rank_worker,
                                            vllm_config, env_dict, config,
                                            weights)
                success = Result.PASS
            except Exception:
                # Any failure is recorded in the matrix rather than
                # aborting the sweep.
                success = Result.FAIL

        results_df = add_to_results(config, success, results_df)

    if results_df is not None:
        results_df.to_csv(f"{csv_file_path}")
if __name__ == '__main__':
    # Script entry point: parse the output path, validate it and generate
    # the feature matrix CSV.
    import argparse
    from pathlib import Path
    parser = argparse.ArgumentParser(description=(
        "Make ModularKernel feature matrix \n"
        "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix "  #noqa: E501
        "-f ./feature_matrices/feature_matrix.csv"))

    parser.add_argument("-f",
                        "--feature-matrix-csv-file-path",
                        type=str,
                        required=True,
                        help="File name to Generate a .csv file")
    args = parser.parse_args()

    csv_path = args.feature_matrix_csv_file_path
    # Require the ".csv" extension itself; a bare "csv" suffix would also
    # accept names such as "matrixcsv", contradicting the message below.
    assert csv_path.endswith(
        '.csv'), f"Need a file path ending with .csv, got {csv_path}"
    assert Path(csv_path).parent.is_dir(
    ), f"Cannot find parent directory for {Path(csv_path).parent}"

    make_feature_matrix(args.feature_matrix_csv_file_path)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
# PrepareFinalize classes that need more than one GPU. Each entry is only
# added when its backing package is available.
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = []
if has_pplx():
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(PplxPrepareAndFinalize)
if has_deep_ep():
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.extend(
        [DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize])

# PrepareFinalize classes that run on a single GPU.
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP]

# Multi-GPU types first, then single-GPU types.
MK_ALL_PREPARE_FINALIZE_TYPES = [
    *MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    *MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
]
# Every fused-experts implementation exercised by the tools.
MK_FUSED_EXPERT_TYPES = [
    BatchedDeepGemmExperts,
    BatchedTritonExperts,
    NaiveBatchedExperts,
    BatchedTritonOrDeepGemmExperts,
    CutlassExpertsFp8,
    DeepGemmExperts,
    TritonOrDeepGemmExperts,
    TritonExperts,
]

# Quantization settings to sweep over; None means unquantized.
MK_QUANT_CONFIGS = [
    None,
    # per-channel / per-column weights and per-tensor activations
    FusedMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        per_out_ch_quant=True,
        block_shape=None,
    ),
    # per-channel / per-column weights and per-token activations
    FusedMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=True,
        per_out_ch_quant=True,
        block_shape=None,
    ),
    # per-tensor weights and per-tensor activations
    FusedMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=None,
    ),
    # per-tensor weights and per-token activations
    FusedMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=True,
        per_out_ch_quant=False,
        block_shape=None,
    ),
    # block-quantized weights and 128 block per-token activations
    FusedMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=[128, 128],
    ),
    # TODO (varun) : Should we test the following combinations ?
    # block-quantized weights and per-token activations
    # block-quantized weights and per-tensor activations
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import traceback
from typing import Any, Callable, Optional
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.utils import get_open_port
## Parallel Processes Utils
# ParamSpec standing for the extra arguments forwarded to the worker
# callable (used in the Callable[Concatenate[..., P], None] annotations and
# as P.args / P.kwargs below).
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
    """Identity of one worker process within the launched group."""
    # Total number of ranks.
    world_size: int
    # Number of ranks per node.
    world_local_size: int
    # Global rank (node_rank * world_local_size + local_rank in the
    # launcher).
    rank: int
    node_rank: int
    # Rank within the node; also the CUDA device index used by the launcher.
    local_rank: int
    device: torch.device
def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
                     local_rank: int):
    """Initialize vLLM's distributed state for this rank.

    With ``vllm_config`` installed as the current config, the distributed
    environment and the model-parallel groups are initialized, and a gloo
    group spanning all ranks is created.

    Args:
        vllm_config: Config whose parallel settings size the TP / PP groups.
        world_size: Total number of ranks.
        rank: Global rank of this process.
        local_rank: Rank of this process within its node.

    Returns:
        The gloo process group covering ``range(world_size)``.
    """
    import os
    import tempfile

    # mkstemp returns an already-open descriptor along with the path; only
    # the path is needed here, so close the descriptor instead of leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)

    # set_current_vllm_config only takes effect as a context manager; the
    # former bare call before this block did nothing and was dropped.
    with set_current_vllm_config(vllm_config):
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=f"file://{temp_file}",
            local_rank=local_rank,
            backend="nccl",
        )

        initialize_model_parallel(
            tensor_model_parallel_size=vllm_config.parallel_config.
            tensor_parallel_size,
            pipeline_model_parallel_size=vllm_config.parallel_config.
            pipeline_parallel_size,
        )
        cpu_group = torch.distributed.new_group(list(range(world_size)),
                                                backend="gloo")
    return cpu_group
def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
                                 P], None],
    vllm_config: Optional[VllmConfig],
    env_dict: Optional[dict],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Entry point executed inside each spawned process.

    Binds the process to its CUDA device, joins the torch.distributed process
    group, applies ``env_dict`` to the environment, optionally sets up vLLM's
    distributed state, and finally invokes ``worker``. The process group is
    destroyed when the worker returns or raises.

    Args:
        local_rank: Index of this process on its node; selects the CUDA
            device.
        world_size: Total number of processes.
        world_local_size: Number of processes per node.
        node_rank: Index of this node.
        init_method: torch.distributed rendezvous URL.
        worker: Callable invoked as
            ``worker(pgi, vllm_config, cpu_group, *args, **kwargs)``.
        vllm_config: When not None, vLLM distributed state is initialized
            via ``_set_vllm_config`` before the worker runs.
        env_dict: When not None, merged into ``os.environ``.
        *args: Extra positional arguments for ``worker``.
        **kwargs: Extra keyword arguments for ``worker``.

    Raises:
        Exception: Whatever ``worker`` raises is printed, then re-raised.
    """
    global_rank = node_rank * world_local_size + local_rank

    torch.cuda.set_device(local_rank)
    cuda_device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=global_rank,
        world_size=world_size,
        device_id=cuda_device,
    )

    # A one-element all-reduce over every rank; presumably serves as a
    # start-up barrier before any real work begins.
    sync_token = torch.tensor([global_rank], device=cuda_device)
    torch.distributed.all_reduce(sync_token)

    if env_dict is not None:
        os.environ.update(env_dict)

    # The gloo CPU group only exists when vLLM distributed state is set up.
    cpu_group = (None if vllm_config is None else _set_vllm_config(
        vllm_config, world_size, global_rank, local_rank))

    try:
        pgi = ProcessGroupInfo(
            world_size=world_size,
            world_local_size=world_local_size,
            rank=global_rank,
            node_rank=node_rank,
            local_rank=local_rank,
            device=cuda_device,
        )
        worker(pgi, vllm_config, cpu_group, *args, **kwargs)
    except Exception as ex:
        # Print the failure from inside the child process before
        # propagating it to the parent.
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()
def parallel_launch_with_config(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None],
    vllm_config: VllmConfig,
    env_dict: dict[Any, Any],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Spawn ``world_size`` processes on this machine and run ``worker`` in each.

    Every process starts in ``_worker_parallel_launch``, which sets up the
    process group and then calls ``worker``. The call blocks until all
    processes have finished.

    Args:
        world_size: Number of processes to spawn.
        worker: Callable run in each process.
        vllm_config: vLLM config forwarded to each process.
        env_dict: Environment variables forwarded to each process.
        *args: Extra positional arguments forwarded to ``worker``.
        **kwargs: Must be empty; only positional arguments are forwarded.
    """
    # Only a positional-argument tuple is handed to spawn(), so keyword
    # arguments cannot be forwarded.
    assert not kwargs

    init_method = (
        f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}")
    launch_args = (
        world_size,  # world_size
        world_size,  # world_local_size
        0,  # node_rank
        init_method,
        worker,
        vllm_config,
        env_dict,
    ) + args

    spawn(
        _worker_parallel_launch,
        args=launch_args,
        nprocs=world_size,
        join=True,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from itertools import product
from typing import Any, Callable
import torch
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from .common import Config, RankTensors, WeightTensors, make_modular_kernel
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
def do_profile(fn: Callable,
               fn_kwargs: dict[Any, Any],
               pgi: ProcessGroupInfo,
               config: Config,
               num_warmups: int = 5):
    """Profile one call of ``fn`` and export a Chrome trace.

    ``fn(**fn_kwargs)`` is first run ``num_warmups`` times outside the
    profiler, then once under ``torch.profiler`` with CPU and CUDA activities
    recorded. The trace is written to
    ``<config.torch_trace_dir_path>/m<M>_topk<topk>_rank<rank>_trace.json``.

    Args:
        fn: Callable to profile.
        fn_kwargs: Keyword arguments passed to ``fn`` on every call.
        pgi: Process-group info; its ``rank`` is part of the trace file name.
        config: Supplies ``torch_trace_dir_path``, ``M`` and ``topk``.
        num_warmups: Number of un-profiled warm-up calls.
    """
    for _ in range(num_warmups):
        fn(**fn_kwargs)

    with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            with_stack=True,
            record_shapes=True,
    ) as tprof:
        fn(**fn_kwargs)
        # Wait for the device to finish while the profiler is still active.
        torch.cuda.synchronize(torch.cuda.current_device())

    # The caller sweeps over (M, topk) pairs, so both must appear in the file
    # name; keyed on M and rank alone, runs that differ only in topk would
    # overwrite each other's trace.
    trace_path = (f"{config.torch_trace_dir_path}/"
                  f"m{config.M}_topk{config.topk}_rank{pgi.rank}_trace.json")
    tprof.export_chrome_trace(trace_path)
def profile_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> None:
    """Build the modular kernel for ``config`` and profile its forward pass.

    Args:
        pgi: Process-group info for this rank.
        vllm_config: vLLM config handed to ``make_modular_kernel``.
        config: Run configuration; ``Ms`` and ``topks`` must already be
            single ints rather than lists.
        weights: Full expert weights; sliced down to this rank's experts.
        rank_tensors: Per-rank input tensors.
    """
    # A sweep must be narrowed to one (M, topk) point before profiling.
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # Slice of the weights that belongs to this rank.
    local_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    modular_kernel = make_modular_kernel(config, vllm_config)

    forward_kwargs = {
        "hidden_states": rank_tensors.hidden_states,
        "w1": local_weights.w1,
        "w2": local_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": rank_tensors.topk_ids,
        "expert_map": rank_tensors.expert_map,
        "w1_scale": local_weights.w1_scale,
        "w2_scale": local_weights.w2_scale,
        "a1_scale": rank_tensors.hidden_states_scale,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1,
    }

    do_profile(modular_kernel.forward, forward_kwargs, pgi, config)
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    """Per-rank driver: profile every (m, topk) combination in ``config``.

    Args:
        pgi: Process-group info for this rank; ``pgi.rank`` seeds the RNGs.
        vllm_config: vLLM config forwarded to the profiling routine.
        cpu_group: Passed in by the launcher; not used here.
        config: Sweep configuration; ``Ms`` and ``topks`` must be lists.
        weights: Expert weights; moved to the current device before use.
    """
    current_platform.seed_everything(pgi.rank)

    # Sanity check: a requested chunk size must match what vLLM's
    # environment settings report.
    from vllm import envs
    if config.fused_moe_chunk_size is not None:
        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # Bring the weights onto this process's device.
    weights.to_current_device()

    m_values = config.Ms
    assert isinstance(m_values, list)
    topk_values = config.topks
    assert isinstance(topk_values, list)

    for m, topk in product(m_values, topk_values):
        print(f"Running m={m}, topk={topk} ...")

        # Work on a private copy pinned to this single (m, topk) point so the
        # sweep configuration itself stays untouched.
        point_config = copy.deepcopy(config)
        point_config.Ms = m
        point_config.topks = topk

        # Inputs for this rank at the selected point.
        rank_tensors = RankTensors.make(point_config, pgi)
        profile_modular_kernel(pgi, vllm_config, point_config, weights,
                               rank_tensors)
def run(config: Config):
    """Create the weights for ``config`` and profile it on all ranks.

    Builds the weight tensors and the (vLLM config, environment) pair from
    ``config``, then launches ``config.world_size`` processes that each
    execute ``rank_worker``.
    """
    moe_weights: WeightTensors = WeightTensors.make(config)
    launch_vllm_config, launch_env = config.make_env_data()
    parallel_launch_with_config(
        config.world_size,
        rank_worker,
        launch_vllm_config,
        launch_env,
        config,
        moe_weights,
    )
if __name__ == '__main__':
    # Script entry point: parse the CLI arguments, build a Config from them
    # and profile it.
    from .cli_args import make_config, make_config_arg_parser

    # Adjacent string literals are concatenated as-is, so the first sentence
    # must carry its own trailing separator.
    parser = make_config_arg_parser(description=(
        "Run single prepare-finalize & fused-experts combination test. "
        "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel "  #noqa: E501
        "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
    ))
    args = parser.parse_args()

    # The trace directory is required for profiling. Report a missing value
    # through the parser rather than with `assert`, which is stripped when
    # Python runs with -O. (Assumes the parser is an argparse.ArgumentParser.)
    if args.torch_trace_dir_path is None:
        parser.error("Please pass in a directory to store torch traces")

    config = make_config(args)

    run(config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment