Commit 7a985548 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.0' into v0.9.0-ori

parents 45d3785c dc1440cf
# SPDX-License-Identifier: Apache-2.0
import base64
import io
import shutil
from tempfile import TemporaryDirectory
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoConfig, AutoTokenizer
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# LoRA adapter repository downloaded by the zephyr_lora_files fixture.
LORA_NAME = "typeof/zephyr-7b-beta-lora"
# Hugging Face config of the base model; create_dummy_embeds reads its
# hidden_size to size the random embedding tensors.
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module")
def zephyr_lora_files():
    """Fetch the LoRA adapter repo named by LORA_NAME once per module.

    Returns whatever ``snapshot_download`` returns for that repository.
    """
    lora_snapshot = snapshot_download(repo_id=LORA_NAME)
    return lora_snapshot
@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
    """Yield a temporary copy of the LoRA adapter with extra tokens added.

    The adapter files are copied into a temporary directory, the base
    model's tokenizer is extended with three new special tokens, and the
    tokenizer is saved next to the copied adapter files.

    Yields:
        Path of the temporary adapter directory.
    """
    # Use the context manager so the directory is removed even when setup
    # fails (e.g. the assertion below) instead of relying on reaching an
    # explicit cleanup() call after the yield.
    with TemporaryDirectory() as tmp_dir:
        tmp_model_dir = f"{tmp_dir}/zephyr"
        shutil.copytree(zephyr_lora_files, tmp_model_dir)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Copy tokenizer to adapter and add some unique tokens
        # 32000, 32001, 32002
        added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
                                     special_tokens=True)
        assert added == 3
        tokenizer.save_pretrained(tmp_model_dir)
        yield tmp_model_dir
@pytest.fixture(scope="module")
def default_server_args(
    zephyr_lora_files,
    zephyr_lora_added_tokens_files,
) -> list[str]:
    """Return the base CLI flags used to launch the test server.

    The two LoRA fixtures are requested but their values are not used in
    the returned argument list.
    """
    # use half precision for speed and memory savings in CI environment
    engine_flags = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]
    # Prompt Embeds server args
    prompt_embeds_flags = [
        "--enable-prompt-embeds",
        "--no-enable-chunked-prefill",
    ]
    return engine_flags + prompt_embeds_flags
@pytest.fixture(scope="module",
                params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
    """Start a RemoteOpenAIServer for each parametrized frontend mode.

    The fixture runs once with the default arguments and once with the
    extra flag supplied through ``request.param``.

    Yields:
        The running ``RemoteOpenAIServer`` instance.
    """
    # Work on a copy: default_server_args is a module-scoped fixture, so
    # appending to it in place would leak this parameter's flag into every
    # other consumer of the same list.
    server_args = list(default_server_args)
    if request.param:
        server_args.append(request.param)
    with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
    """Yield an async client obtained from the prompt-embeds server."""
    client_ctx = server_with_prompt_embeds.get_async_client()
    async with client_ctx as openai_client:
        yield openai_client
def create_dummy_embeds(num_tokens: int = 5) -> str:
    """Build a random embedding tensor and return it base64-encoded.

    The tensor has ``num_tokens`` rows and ``CONFIG.hidden_size`` columns;
    it is serialized with ``torch.save`` before being encoded.
    """
    random_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
    with io.BytesIO() as stream:
        torch.save(random_embeds, stream)
        serialized = stream.getvalue()
    encoded = base64.b64encode(serialized)
    return encoded.decode('utf-8')
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    """Exercise the completions endpoint with ``prompt_embeds`` inputs.

    Covers a single embedding, a batch of embeddings, streaming (single
    and batch), and a request that mixes a text prompt with an embedding.
    """
    # Test case: Single prompt embeds input
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None

    # Test case: batch completion with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(completion.choices) == 2
    assert len(completion.choices[0].text) >= 1
    assert len(completion.choices[1].text) >= 1

    # Test case: streaming with prompt_embeds
    encoded_embeds = create_dummy_embeds()
    single_completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    single_output = single_completion.choices[0].text

    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": encoded_embeds})
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # The checks below inspect the last chunk received from the stream.
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert "".join(chunks) == single_output

    # Test case: batch streaming with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    chunks_stream_embeds: list[list[str]] = [[], []]
    finish_reason_count = 0
    async for chunk in stream:
        chunks_stream_embeds[chunk.choices[0].index].append(
            chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 2
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert len(chunks_stream_embeds[0]) > 0
    assert len(chunks_stream_embeds[1]) > 0

    # Test case: mixed text and prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion_mixed = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Check the mixed request's own response; asserting on `completion`
    # here would only re-check the earlier batch response.
    assert len(completion_mixed.choices) == 2

    completion_text_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
    )
    completion_embeds_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Embeddings responses should be handled first
    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
        0].text
    assert completion_mixed.choices[1].text == completion_text_only.choices[
        0].text
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_errors_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    """Expect a BadRequestError when ``prompt_embeds`` cannot be decoded."""
    # Test error case: invalid prompt_embeds
    bad_payload = {"prompt_embeds": "invalid_base64"}
    completions_api = client_with_prompt_embeds.completions
    with pytest.raises(BadRequestError):
        await completions_api.create(prompt="",
                                     model=model_name,
                                     max_tokens=5,
                                     temperature=0.0,
                                     extra_body=bad_payload)
@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_logprobs_and_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
        model_name: str):
    """Verify logprobs returned for single and batched ``prompt_embeds``.

    Each request asks for 5 tokens, so every logprobs field is expected to
    hold 5 entries.
    """

    def _check_choice_logprobs(choice_logprobs) -> None:
        # Shared assertions applied to one choice's logprobs payload.
        assert choice_logprobs is not None
        assert len(choice_logprobs.text_offset) == 5
        assert len(choice_logprobs.token_logprobs) == 5
        assert len(choice_logprobs.top_logprobs) == 5
        # The first position is skipped, as in the per-token bound check.
        for per_token in choice_logprobs.top_logprobs[1:]:
            assert max(logprobs_arg, 1) <= len(per_token) <= logprobs_arg + 1
        assert len(choice_logprobs.tokens) == 5

    completions_api = client_with_prompt_embeds.completions

    # Test case: Logprobs using prompt_embeds
    encoded_embeds = create_dummy_embeds()
    single_result = await completions_api.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": encoded_embeds})
    _check_choice_logprobs(single_result.choices[0].logprobs)

    # Test case: Log probs with batch completion and prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    batch_result = await completions_api.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(batch_result.choices) == 2
    for batch_choice in batch_result.choices:
        _check_choice_logprobs(batch_choice.logprobs)
...@@ -11,7 +11,7 @@ import requests ...@@ -11,7 +11,7 @@ import requests
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from ...models.embedding.utils import correctness_test from ...models.utils import run_embedding_correctness_test
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "intfloat/multilingual-e5-small" MODEL_NAME = "intfloat/multilingual-e5-small"
...@@ -76,7 +76,7 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, ...@@ -76,7 +76,7 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 11 assert embeddings.usage.total_tokens == 11
vllm_outputs = [d.embedding for d in embeddings.data] vllm_outputs = [d.embedding for d in embeddings.data]
correctness_test(hf_model, input_texts, vllm_outputs) run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
# test using token IDs # test using token IDs
input_tokens = [1, 1, 1, 1, 1] input_tokens = [1, 1, 1, 1, 1]
...@@ -121,7 +121,7 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, ...@@ -121,7 +121,7 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 33 assert embeddings.usage.total_tokens == 33
vllm_outputs = [d.embedding for d in embeddings.data] vllm_outputs = [d.embedding for d in embeddings.data]
correctness_test(hf_model, input_texts, vllm_outputs) run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
# test list[list[int]] # test list[list[int]]
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
...@@ -208,7 +208,7 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI, ...@@ -208,7 +208,7 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
model=model_name, model=model_name,
encoding_format="float") encoding_format="float")
float_data = [d.embedding for d in responses_float.data] float_data = [d.embedding for d in responses_float.data]
correctness_test(hf_model, input_texts, float_data) run_embedding_correctness_test(hf_model, input_texts, float_data)
responses_base64 = await client.embeddings.create(input=input_texts, responses_base64 = await client.embeddings.create(input=input_texts,
model=model_name, model=model_name,
...@@ -219,13 +219,13 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI, ...@@ -219,13 +219,13 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
np.frombuffer(base64.b64decode(data.embedding), np.frombuffer(base64.b64decode(data.embedding),
dtype="float32").tolist()) dtype="float32").tolist())
correctness_test(hf_model, input_texts, base64_data) run_embedding_correctness_test(hf_model, input_texts, base64_data)
# Default response is float32 decoded from base64 by OpenAI Client # Default response is float32 decoded from base64 by OpenAI Client
responses_default = await client.embeddings.create(input=input_texts, responses_default = await client.embeddings.create(input=input_texts,
model=model_name) model=model_name)
default_data = [d.embedding for d in responses_default.data] default_data = [d.embedding for d in responses_default.data]
correctness_test(hf_model, input_texts, default_data) run_embedding_correctness_test(hf_model, input_texts, default_data)
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -11,7 +11,7 @@ import pytest ...@@ -11,7 +11,7 @@ import pytest
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from ...conftest import HfRunner from ...conftest import HfRunner
from ...models.embedding.utils import EmbedModelInfo, correctness_test from ...models.utils import EmbedModelInfo, run_embedding_correctness_test
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODELS = [ MODELS = [
...@@ -95,7 +95,8 @@ async def test_matryoshka(model_info: EmbedModelInfo, ...@@ -95,7 +95,8 @@ async def test_matryoshka(model_info: EmbedModelInfo,
assert len(embeddings.data[0].embedding) == dimensions assert len(embeddings.data[0].embedding) == dimensions
vllm_outputs = [d.embedding for d in embeddings.data] vllm_outputs = [d.embedding for d in embeddings.data]
correctness_test(hf_model, prompts, vllm_outputs, dimensions) run_embedding_correctness_test(hf_model, prompts, vllm_outputs,
dimensions)
if model_info.is_matryoshka: if model_info.is_matryoshka:
valid_dimensions: list[Optional[int]] = [None] valid_dimensions: list[Optional[int]] = [None]
......
...@@ -44,6 +44,6 @@ schema = schemathesis.from_pytest_fixture("get_schema") ...@@ -44,6 +44,6 @@ schema = schemathesis.from_pytest_fixture("get_schema")
@schema.parametrize() @schema.parametrize()
@schema.override(headers={"Content-Type": "application/json"}) @schema.override(headers={"Content-Type": "application/json"})
async def test_openapi_stateless(case): def test_openapi_stateless(case: schemathesis.Case):
#No need to verify SSL certificate for localhost #No need to verify SSL certificate for localhost
await case.call_and_validate(verify=False) case.call_and_validate(verify=False)
...@@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config(): ...@@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config():
assert mock_engine.generate.call_args.args[1].temperature == 0.0 assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
def test_serving_chat_did_set_correct_cache_salt():
    """Check that ``cache_salt`` reaches the engine prompt only when set."""
    model_config = MockModelConfig()
    engine = MagicMock(spec=MQLLMEngineClient)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False

    # Initialize the serving chat
    serving_models = OpenAIServingModels(engine_client=engine,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=model_config)
    serving_chat = OpenAIServingChat(engine,
                                     model_config,
                                     serving_models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)

    def _engine_prompt_for(chat_request):
        # Any exception from the mocked pipeline is ignored; only the
        # arguments captured by the mocked generate() call matter here.
        with suppress(Exception):
            asyncio.run(serving_chat.create_chat_completion(chat_request))
        return engine.generate.call_args.args[0]

    # Test cache_salt
    user_message = {"role": "user", "content": "what is 1+1?"}
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[user_message],
    )

    # By default cache_salt in the engine prompt is not set
    assert "cache_salt" not in _engine_prompt_for(req)

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    assert _engine_prompt_for(req)["cache_salt"] == "test_salt"
...@@ -145,6 +145,83 @@ async def test_tokenize_chat( ...@@ -145,6 +145,83 @@ async def test_tokenize_chat(
} }
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenize_chat_with_tools(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """Compare the tokenize endpoint with local tokenization when tools are
    supplied, across generation-prompt, special-token and continue-final
    settings."""
    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                              tokenizer_mode="fast")

    for add_generation in (False, True):
        for add_special in (False, True):
            user_turn = {
                "role": "user",
                "content": "What's the weather like in Paris today?",
            }
            conversation = [user_turn]

            location_schema = {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    }
                },
            }
            tools = [{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": location_schema,
                },
            }]

            for continue_final in (False, True):
                if continue_final:
                    # A generation prompt cannot be combined with
                    # continuing the final message, so skip that pairing.
                    if add_generation:
                        continue
                    conversation.append({
                        "role": "assistant",
                        "content": "Sure,"
                    })

                rendered_prompt = tokenizer.apply_chat_template(
                    add_generation_prompt=add_generation,
                    continue_final_message=continue_final,
                    conversation=conversation,
                    tools=tools,
                    tokenize=False,
                )
                tokens = tokenizer.encode(rendered_prompt,
                                          add_special_tokens=add_special)

                request_payload = {
                    "add_generation_prompt": add_generation,
                    "continue_final_message": continue_final,
                    "add_special_tokens": add_special,
                    "messages": conversation,
                    "model": model_name,
                    "tools": tools,
                }
                response = requests.post(server.url_for("tokenize"),
                                         json=request_payload)
                response.raise_for_status()

                expected_body = {
                    "tokens": tokens,
                    "count": len(tokens),
                    "max_model_len": 8192,
                }
                assert response.json() == expected_body
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name,tokenizer_name", "model_name,tokenizer_name",
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import openai
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
# Embedding model served by the `server` fixture in this module.
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
# Passed to the server as --max-model-len and used as the reference bound
# in the truncation tests.
max_model_len = 128
input = """Immerse yourself in the enchanting chronicle of calculus, a
mathematical domain that has radically transformed our comprehension of
change and motion. Despite its roots in ancient civilizations, the
formal birth of calculus predominantly occurred in the 17th century,
primarily under the influential guidance of Sir Isaac Newton and Gottfried
Wilhelm Leibniz. The earliest traces of calculus concepts are found in
ancient Greek mathematics,most notably in the works of Eudoxus and
Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a
technique for computing areas and volumes through the use of finite sums.
This methodology laid crucial foundational work for integral calculus.
In the 17th century, both Newton and Leibniz independently pioneered
calculus, each contributing unique perspectives that would shape this new
field."""
@pytest.fixture(scope="module")
def server():
    """Launch an embedding server limited to ``max_model_len`` tokens."""
    cli_flags = ["--task", "embed"]
    cli_flags += ["--dtype", "bfloat16"]
    cli_flags.append("--enforce-eager")
    cli_flags += ["--max-model-len", str(max_model_len)]

    with RemoteOpenAIServer(MODEL_NAME, cli_flags) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the embedding server."""
    client_ctx = server.get_async_client()
    async with client_ctx as openai_client:
        yield openai_client
@pytest.mark.asyncio
async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
    """A truncation size of 10 should yield exactly 10 prompt tokens."""
    truncation_size = 10
    request_body: dict[str, Any] = dict(
        model=MODEL_NAME,
        input=input,
        truncate_prompt_tokens=truncation_size,
    )

    # Send a copy of the body, leaving request_body itself untouched.
    response = await client.post(path="embeddings",
                                 cast_to=object,
                                 body=dict(request_body))

    usage = response["usage"]
    assert usage["prompt_tokens"] == truncation_size
@pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
    """A truncation size above ``max_model_len`` must be rejected.

    The server is expected to answer with HTTP 400 and an error message
    naming both the requested truncation size and the model length limit.
    """
    truncation_size = max_model_len + 1
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "input": input,
        "truncate_prompt_tokens": truncation_size
    }

    # Keep only the raising call inside the context manager and do not
    # rebind `err`; otherwise the checks below either never run or inspect
    # the wrong object.
    with pytest.raises(openai.BadRequestError) as err:
        await client.post(path="embeddings", cast_to=object, body={**kwargs})

    assert err.value.status_code == 400
    # Match on stable fragments rather than on one whitespace-sensitive
    # multi-line literal.
    message = str(err.value)
    assert "truncate_prompt_tokens value" in message
    assert f"({truncation_size})" in message
    assert "is greater than max_model_len" in message
    assert f"({max_model_len})" in message
    assert "Please, select a smaller truncation size." in message
@pytest.mark.asyncio
async def test_max_truncation_size(client: openai.AsyncOpenAI):
    """A truncation size of -1 should yield ``max_model_len`` prompt tokens."""
    truncation_size = -1
    request_body: dict[str, Any] = dict(
        model=MODEL_NAME,
        input=input,
        truncate_prompt_tokens=truncation_size,
    )

    # Send a copy of the body, leaving request_body itself untouched.
    response = await client.post(path="embeddings",
                                 cast_to=object,
                                 body=dict(request_body))

    usage = response["usage"]
    assert usage["prompt_tokens"] == max_model_len
...@@ -32,7 +32,7 @@ class StreamingToolReconstructor: ...@@ -32,7 +32,7 @@ class StreamingToolReconstructor:
assert len(delta.tool_calls) < 2, ( assert len(delta.tool_calls) < 2, (
"Streaming should include only one tool call per update.") "Streaming should include only one tool call per update.")
for call_delta in delta.tool_calls: for call_delta in delta.tool_calls:
assert call_delta.type == "function", ( assert call_delta.type is None or call_delta.type == "function", (
"Streaming tool calls should only emit function calls. Got " "Streaming tool calls should only emit function calls. Got "
f"{call_delta.type}") f"{call_delta.type}")
current_tool_call = self.tool_calls[ current_tool_call = self.tool_calls[
......
...@@ -4,8 +4,6 @@ import warnings ...@@ -4,8 +4,6 @@ import warnings
from typing import Optional from typing import Optional
import pytest import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict ...@@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64 from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH from ..utils import VLLM_PATH
EXAMPLES_DIR = VLLM_PATH / "examples" EXAMPLES_DIR = VLLM_PATH / "examples"
...@@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ...@@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
enable_lora=False, enable_lora=False,
max_num_seqs=5, max_num_seqs=5,
max_input_length=None, max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
) )
tokenizer = tokenizer_group.tokenizer tokenizer = tokenizer_group.tokenizer
...@@ -793,10 +793,10 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ...@@ -793,10 +793,10 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
) )
vllm_result = apply_hf_chat_template( vllm_result = apply_hf_chat_template(
tokenizer, tokenizer=tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation, conversation=conversation,
chat_template=None, chat_template=None,
model_config=model_config,
tools=None, tools=None,
add_generation_prompt=True, add_generation_prompt=True,
) )
...@@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ...@@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
@pytest.mark.parametrize("use_tools", [True, False]) @pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models.""" """checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
# Build the tokenizer group and grab the underlying tokenizer # Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
...@@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ...@@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
enable_lora=False, enable_lora=False,
max_num_seqs=5, max_num_seqs=5,
max_input_length=None, max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
) )
tokenizer = tokenizer_group.tokenizer tokenizer = tokenizer_group.tokenizer
...@@ -837,7 +848,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ...@@ -837,7 +848,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
tokenizer, tokenizer,
chat_template=None, chat_template=None,
tools=tools, tools=tools,
trust_remote_code=True, model_config=model_config,
) )
assert isinstance(chat_template, str) assert isinstance(chat_template, str)
...@@ -857,15 +868,23 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ...@@ -857,15 +868,23 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
) )
# yapf: enable # yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format): def test_resolve_content_format_hf_defined(model, expected_format):
if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version( model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
"4.49.0"): model_info.check_available_online(on_fail="skip")
pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
model, model,
enable_lora=False, enable_lora=False,
max_num_seqs=5, max_num_seqs=5,
max_input_length=None, max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
) )
tokenizer = tokenizer_group.tokenizer tokenizer = tokenizer_group.tokenizer
...@@ -874,7 +893,7 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -874,7 +893,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
tokenizer, tokenizer,
chat_template=None, chat_template=None,
tools=None, tools=None,
trust_remote_code=True, model_config=model_config,
) )
assert isinstance(chat_template, str) assert isinstance(chat_template, str)
...@@ -888,7 +907,66 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -888,7 +907,66 @@ def test_resolve_content_format_hf_defined(model, expected_format):
None, None,
"auto", "auto",
tokenizer, tokenizer,
trust_remote_code=True, model_config=model_config,
)
assert resolved_format == expected_format
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[("Salesforce/blip2-opt-2.7b", "string"),
("facebook/chameleon-7b", "string"),
("deepseek-ai/deepseek-vl2-tiny", "string"),
("microsoft/Florence-2-base", "string"),
("adept/fuyu-8b", "string"),
("google/paligemma-3b-mix-224", "string"),
("Qwen/Qwen-VL", "string"),
("Qwen/Qwen-VL-Chat", "string")],
)
# yapf: enable
def test_resolve_content_format_fallbacks(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
tokenizer_group = TokenizerGroup(
model_config.tokenizer,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
model_config=model_config,
)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
model_config=model_config,
) )
assert resolved_format == expected_format assert resolved_format == expected_format
...@@ -899,17 +977,13 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -899,17 +977,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
("template_path", "expected_format"), ("template_path", "expected_format"),
[("template_alpaca.jinja", "string"), [("template_alpaca.jinja", "string"),
("template_baichuan.jinja", "string"), ("template_baichuan.jinja", "string"),
("template_blip2.jinja", "string"),
("template_chatglm.jinja", "string"), ("template_chatglm.jinja", "string"),
("template_chatglm2.jinja", "string"), ("template_chatglm2.jinja", "string"),
("template_chatml.jinja", "string"), ("template_chatml.jinja", "string"),
("template_deepseek_vl2.jinja", "string"),
("template_dse_qwen2_vl.jinja", "openai"), ("template_dse_qwen2_vl.jinja", "openai"),
("template_falcon_180b.jinja", "string"), ("template_falcon_180b.jinja", "string"),
("template_falcon.jinja", "string"), ("template_falcon.jinja", "string"),
("template_florence2.jinja", "string"),
("template_inkbot.jinja", "string"), ("template_inkbot.jinja", "string"),
("template_llava.jinja", "string"),
("template_teleflm.jinja", "string"), ("template_teleflm.jinja", "string"),
("template_vlm2vec.jinja", "openai"), ("template_vlm2vec.jinja", "openai"),
("tool_chat_template_granite_20b_fc.jinja", "string"), ("tool_chat_template_granite_20b_fc.jinja", "string"),
...@@ -922,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -922,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format):
) )
# yapf: enable # yapf: enable
def test_resolve_content_format_examples(template_path, expected_format): def test_resolve_content_format_examples(template_path, expected_format):
model_config = ModelConfig(
PHI3V_MODEL_ID, # Dummy
tokenizer=PHI3V_MODEL_ID, # Dummy
trust_remote_code=True,
)
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
PHI3V_MODEL_ID, PHI3V_MODEL_ID, # Dummy
enable_lora=False, enable_lora=False,
max_num_seqs=5, max_num_seqs=5,
max_input_length=None, max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
) )
dummy_tokenizer = tokenizer_group.tokenizer dummy_tokenizer = tokenizer_group.tokenizer
dummy_tokenizer.chat_template = None dummy_tokenizer.chat_template = None
...@@ -944,7 +1025,7 @@ def test_resolve_content_format_examples(template_path, expected_format): ...@@ -944,7 +1025,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
None, None,
"auto", "auto",
dummy_tokenizer, dummy_tokenizer,
trust_remote_code=True, model_config=model_config,
) )
assert resolved_format == expected_format assert resolved_format == expected_format
...@@ -102,7 +102,10 @@ def test_env( ...@@ -102,7 +102,10 @@ def test_env(
block_size, block_size,
False, False,
use_mla=use_mla) use_mla=use_mla)
assert backend.get_name() == name if use_v1 and name != "TRITON_MLA":
assert backend.get_name() == f"{name}_VLLM_V1"
else:
assert backend.get_name() == name
else: else:
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
get_attn_backend(16, get_attn_backend(16,
...@@ -185,8 +188,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): ...@@ -185,8 +188,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
# Unsupported CUDA arch # Unsupported CUDA arch
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda: monkeypatch.setattr(torch.cuda,
(7, 5)) "get_device_capability",
lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16, False) backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
......
...@@ -5,11 +5,11 @@ import random ...@@ -5,11 +5,11 @@ import random
import pytest import pytest
import torch import torch
import triton
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
is_flashmla_supported) is_flashmla_supported)
from vllm.triton_utils import triton
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
......
...@@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True) False, True)
assert backend.get_name() == "ROCM_AITER_MLA" assert (backend.get_name() == "ROCM_AITER_MLA"
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
# If attention backend is None # If attention backend is None
# If use_mla is true # If use_mla is true
...@@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv("VLLM_ROCM_USE_AITER", "1") m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True) False, True)
assert backend.get_name() == "ROCM_AITER_MLA" assert (backend.get_name() == "ROCM_AITER_MLA"
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.platforms import current_platform
# Parametrization grids for test_triton_unified_attn.
# Each NUM_HEADS entry is (num_query_heads, num_kv_heads); the test asserts
# that the first is a multiple of the second.
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
# dtype used to draw the query and the key/value caches.
DTYPES = [torch.float16, torch.bfloat16]
# Optional dtype the inputs are cast to before calling the kernel;
# None leaves them in their original dtype.
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: list[int],
    kv_lens: list[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Pure-PyTorch reference for causal attention over a paged KV cache.

    Sequences are processed one at a time: the KV blocks listed in the
    sequence's row of ``block_tables`` are gathered, truncated to the
    sequence's KV length, and attended to with a causal mask (optionally
    combined with a sliding-window mask and logit soft-capping).

    Args:
        query: Packed queries of all sequences, indexed as
            (token, query_head, head_size).
        key_cache: Paged keys, indexed as
            (block, block_size, num_kv_heads, head_size).
        value_cache: Paged values, same layout as ``key_cache``.
        query_lens: Number of query tokens of each sequence.
        kv_lens: Number of valid KV tokens of each sequence.
        block_tables: Per-sequence rows of cache block indices.
        scale: Factor applied to the queries before the dot product.
        sliding_window: If given, keys further back than this window are
            masked out as well.
        soft_cap: If given and positive, logits are squashed with
            ``soft_cap * tanh(logits / soft_cap)`` before masking.

    Returns:
        The attention outputs of all sequences concatenated along dim 0.

    Note:
        The masks are created without an explicit device, i.e. on the
        current default device, which therefore has to match the inputs.
    """
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: list[torch.Tensor] = []
    start_idx = 0
    for i, query_len in enumerate(query_lens):
        kv_len = kv_lens[i]
        # Scale out of place: the slice is a view of ``query``, so an
        # in-place multiply would silently modify the caller's tensor.
        q = query[start_idx:start_idx + query_len] * scale

        # Gather this sequence's blocks and drop the padding of the last one.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Grouped-query attention: replicate KV heads to match query heads.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # The queries are the last ``query_len`` positions of the sequence,
        # hence the diagonal offset of the causal mask.
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(empty_mask,
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None and soft_cap > 0:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("seq_lens",
                         [[(1, 1328), (5, 18),
                           (129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_triton_unified_attn(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    """Compare ``unified_attention`` with the PyTorch reference.

    Each ``seq_lens`` entry is a (query_len, kv_len) pair and ``num_heads``
    is (num_query_heads, num_kv_heads). When ``q_dtype`` is given, Q/K/V are
    cast to it before the kernel call while the reference still runs on the
    original tensors, so looser tolerances are used in that case.
    """
    torch.set_default_device("cuda")

    if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32:
        pytest.skip("block size must be at least 32 for fp8")

    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    # NOTE: the order of the random draws below determines the test data.
    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    # Prefix sums of the query lengths, starting at 0.
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    output = torch.empty_like(query)

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None  # Not yet supported, stays None in every case
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        k_descale = torch.rand(scale_shape, dtype=torch.float32)
        v_descale = torch.rand(scale_shape, dtype=torch.float32)

    unified_attention(
        q=maybe_quantized_query,
        k=maybe_quantized_key_cache,
        v=maybe_quantized_value_cache,
        out=output,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1

    def _with_max_abs_diff(default_msg: str) -> str:
        # Only evaluated on failure; appends the largest deviation to the
        # message generated by assert_close.
        max_diff = torch.max(torch.abs(output - ref_output))
        return f"{default_msg}\nmax abs diff: {max_diff}"

    torch.testing.assert_close(output,
                               ref_output,
                               atol=atol,
                               rtol=rtol,
                               msg=_with_max_abs_diff)
...@@ -21,6 +21,7 @@ SEEDS = [0] ...@@ -21,6 +21,7 @@ SEEDS = [0]
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
USE_KEY = [True, False]
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
...@@ -28,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, ...@@ -28,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
return (batch_size, seq_len, num_heads * head_size) return (batch_size, seq_len, num_heads * head_size)
# For testing sliced tensors
def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size + 64)
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]: head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size) return (batch_size, seq_len, num_heads, head_size)
TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape] TENSORS_SHAPES_FN = [
_get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
]
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
...@@ -46,6 +55,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape] ...@@ -46,6 +55,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode() @torch.inference_mode()
def test_rotary_embedding( def test_rotary_embedding(
is_neox_style: bool, is_neox_style: bool,
...@@ -58,6 +68,7 @@ def test_rotary_embedding( ...@@ -58,6 +68,7 @@ def test_rotary_embedding(
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: str, device: str,
use_key: bool,
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: int = 10000,
) -> None: ) -> None:
...@@ -74,7 +85,11 @@ def test_rotary_embedding( ...@@ -74,7 +85,11 @@ def test_rotary_embedding(
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype) query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) key = torch.randn_like(query) if use_key else None
# slice tensor if required, noop otherwise
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place. # because the custom kernel is in-place.
...@@ -85,10 +100,14 @@ def test_rotary_embedding( ...@@ -85,10 +100,14 @@ def test_rotary_embedding(
ref_query, ref_query,
atol=get_default_atol(out_query), atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query)) rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key, if use_key:
ref_key, torch.testing.assert_close(out_key,
atol=get_default_atol(out_key), ref_key,
rtol=get_default_rtol(out_key)) atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
...@@ -101,6 +120,7 @@ def test_rotary_embedding( ...@@ -101,6 +120,7 @@ def test_rotary_embedding(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode() @torch.inference_mode()
def test_batched_rotary_embedding( def test_batched_rotary_embedding(
is_neox_style: bool, is_neox_style: bool,
...@@ -113,6 +133,7 @@ def test_batched_rotary_embedding( ...@@ -113,6 +133,7 @@ def test_batched_rotary_embedding(
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: str, device: str,
use_key: bool,
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: int = 10000,
) -> None: ) -> None:
...@@ -129,7 +150,11 @@ def test_batched_rotary_embedding( ...@@ -129,7 +150,11 @@ def test_batched_rotary_embedding(
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype) query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) key = torch.randn_like(query) if use_key else None
# slice tensor if required, noop otherwise
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place. # because the custom kernel is in-place.
...@@ -145,10 +170,14 @@ def test_batched_rotary_embedding( ...@@ -145,10 +170,14 @@ def test_batched_rotary_embedding(
ref_query, ref_query,
atol=get_default_atol(out_query), atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query)) rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key, if use_key:
ref_key, torch.testing.assert_close(out_key,
atol=get_default_atol(out_key), ref_key,
rtol=get_default_rtol(out_key)) atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
...@@ -160,6 +189,7 @@ def test_batched_rotary_embedding( ...@@ -160,6 +189,7 @@ def test_batched_rotary_embedding(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode() @torch.inference_mode()
def test_batched_rotary_embedding_multi_lora( def test_batched_rotary_embedding_multi_lora(
is_neox_style: bool, is_neox_style: bool,
...@@ -171,6 +201,7 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -171,6 +201,7 @@ def test_batched_rotary_embedding_multi_lora(
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: str, device: str,
use_key: bool,
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: int = 10000,
) -> None: ) -> None:
...@@ -190,7 +221,7 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -190,7 +221,7 @@ def test_batched_rotary_embedding_multi_lora(
seq_len, seq_len,
num_heads * head_size, num_heads * head_size,
dtype=dtype) dtype=dtype)
key = torch.randn_like(query) key = torch.randn_like(query) if use_key else None
offset_map = torch.tensor( offset_map = torch.tensor(
list( list(
...@@ -214,10 +245,14 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -214,10 +245,14 @@ def test_batched_rotary_embedding_multi_lora(
ref_query, ref_query,
atol=get_default_atol(out_query), atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query)) rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key, if use_key:
ref_key, torch.testing.assert_close(out_key,
atol=get_default_atol(out_key), ref_key,
rtol=get_default_rtol(out_key)) atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@torch.inference_mode() @torch.inference_mode()
......
...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding ...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def rotary_embedding_opcheck(rot, def rotary_embedding_opcheck(rot,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None): offsets: Optional[torch.Tensor] = None):
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
...@@ -37,9 +37,11 @@ def rotary_embedding_opcheck(rot, ...@@ -37,9 +37,11 @@ def rotary_embedding_opcheck(rot,
@pytest.mark.parametrize("rotary_dim", [32]) @pytest.mark.parametrize("rotary_dim", [32])
@pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024]) @pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.parametrize("use_key", [True, False])
@pytest.mark.parametrize("head_stride_is_contingous", [True, False])
def test_rotary_embedding_opcheck(dist_init, device, max_position, def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size, is_neox_style, rotary_dim, head_size,
seq_len): seq_len, use_key, head_stride_is_contingous):
batch_size = 1 batch_size = 1
base = 10000 base = 10000
num_heads = 7 num_heads = 7
...@@ -49,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, ...@@ -49,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions = torch.randint(0, positions = torch.randint(0,
max_position, (batch_size, seq_len), max_position, (batch_size, seq_len),
device=device) device=device)
head_stride = head_size + (64 if head_stride_is_contingous else 0)
query = torch.randn(batch_size, query = torch.randn(batch_size,
seq_len, seq_len,
num_heads * head_size, num_heads,
head_stride,
dtype=torch.float32, dtype=torch.float32,
device=device) device=device)
key = torch.randn_like(query) key = torch.randn_like(query) if use_key else None
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
rotary_embedding_opcheck(rot, positions, query, key) rotary_embedding_opcheck(rot, positions, query, key)
offsets = torch.zeros(batch_size * seq_len, offsets = torch.zeros(batch_size * seq_len,
device=device, device=device,
dtype=torch.long) dtype=torch.long)
rotary_embedding_opcheck(rot, positions, query, key, offsets) rotary_embedding_opcheck(rot, positions, query, key, offsets)
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
if head_stride_is_contingous:
rotary_embedding_opcheck(
rot, positions, query.flatten(start_dim=-2),
key.flatten(start_dim=-2) if use_key else None)
...@@ -6,7 +6,7 @@ import torch.nn.functional as F ...@@ -6,7 +6,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import ( from vllm.model_executor.layers.mamba.mamba2_metadata import (
_seq_idx_to_chunk_indices_offsets) _query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined) mamba_chunk_scan_combined)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
last_taken, exhausted, n_heads, last_taken, exhausted, n_heads,
d_head, itype): d_head, itype):
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( chunk_indices, chunk_offsets = \
seq_idx, chunk_size) _query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y, new_states = mamba_chunk_scan_combined( Y, new_states = mamba_chunk_scan_combined(
X, X,
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import pytest
import torch
import triton.language as tl
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
@dataclass
class BatchedMMConfig:
    """Problem sizes and element type of one batched expert matmul case."""

    # Element type of the A, B and C tensors.
    dtype: torch.dtype
    # Number of experts, i.e. the leading dimension of every tensor.
    num_experts: int
    # Rows reserved per expert in A and C.
    max_tokens_per_expert: int
    # Reduction dimension shared by A and B.
    K: int
    # Number of output columns in C.
    N: int
@dataclass
class BatchedMMTensors:
    """Input, weight and output tensors of one batched expert matmul case."""

    A: torch.Tensor  # [E, max_tokens, K]
    # [E, K, N] - column major (make_tensors allocates it as [E, N, K])
    B: torch.Tensor
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        """Allocate randomly initialised CUDA tensors described by ``config``.

        A is drawn from a normal distribution and divided by 10, B is drawn
        from a normal distribution, C starts as zeros, and every expert gets
        a random token count in ``[0, max_tokens_per_expert)``.
        """
        experts = config.num_experts
        max_tokens = config.max_tokens_per_expert
        float_opts = {"device": "cuda", "dtype": config.dtype}

        # The draw order (A, B, token counts) fixes the data for a given seed.
        activations = torch.randn((experts, max_tokens, config.K),
                                  **float_opts) / 10
        weights = torch.randn((experts, config.N, config.K), **float_opts)
        result = torch.zeros((experts, max_tokens, config.N), **float_opts)
        token_counts = torch.randint(low=0,
                                     high=max_tokens,
                                     size=(experts, ),
                                     device="cuda",
                                     dtype=torch.int32)
        return BatchedMMTensors(A=activations,
                                B=weights,
                                C=result,
                                num_expert_tokens=token_counts)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:
    """Reference batched expert matmul, written into ``C`` in place.

    For every expert ``e`` only the first ``num_expert_tokens[e]`` rows are
    computed as ``A[e] @ B[e]^T``; the remaining rows of ``C`` are left
    untouched. Returns the same ``C`` object that was passed in.
    """
    # Fetch all per-expert counts with a single host transfer.
    token_counts = num_expert_tokens.to(device="cpu").tolist()
    for expert, count in enumerate(token_counts):
        rows = A[expert, :count, :]
        C[expert, :count, :] = rows @ B[expert].transpose(0, 1)
    return C
@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):
    """Check the batched Triton MoE matmul kernel against ``ref_impl``."""
    # Triton compute type matching each torch dtype under test.
    triton_dtype_of = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }
    # (rtol, atol) per dtype.
    tolerance_of = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }
    block_config = {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 16
    }

    tensors = BatchedMMTensors.make_tensors(
        BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N))

    kernel_out = tensors.C
    # Copy the still-zero output buffer before the kernel writes into it.
    expected = kernel_out.clone()

    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        kernel_out,
        tensors.num_expert_tokens,
        triton_dtype_of[kernel_out.dtype],
        # Quantization data
        None,
        None,
        None,
        # Quantization schemes
        False,
        False,
        False,
        config=block_config)

    expected = ref_impl(tensors.A, tensors.B, expected,
                        tensors.num_expert_tokens)

    rtol, atol = tolerance_of[kernel_out.dtype]
    torch.testing.assert_close(kernel_out, expected, atol=atol, rtol=rtol)
...@@ -30,6 +30,11 @@ MNK_FACTORS = [ ...@@ -30,6 +30,11 @@ MNK_FACTORS = [
(224, 3072, 1536), (224, 3072, 1536),
] ]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclasses.dataclass @dataclasses.dataclass
class MOETensors: class MOETensors:
...@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ...@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr]
'topk_weights': topk_weights, 'topk_weights': topk_weights,
'topk_ids_': topk_ids, 'topk_ids': topk_ids,
'ab_strides1': moe_tensors.ab_strides1, 'ab_strides1': moe_tensors.ab_strides1,
'c_strides1': moe_tensors.c_strides1, 'c_strides1': moe_tensors.c_strides1,
'ab_strides2': moe_tensors.ab_strides2, 'ab_strides2': moe_tensors.ab_strides2,
...@@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph( ...@@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph(
per_out_ch: bool, per_out_ch: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
with set_current_vllm_config( with set_current_vllm_config(vllm_config):
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch) per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=torch.half) score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids = fused_topk(mt.a, topk_weights, topk_ids, _ = fused_topk(mt.a,
score, score,
topk, topk,
renormalize=False) renormalize=False)
# Note that we are using the dequantized versions of the tensors. # Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences. # Using a, w1 and w2 directly results in minor output differences.
...@@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph( ...@@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_out_ch: bool, per_out_ch: bool,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
with set_current_vllm_config( with set_current_vllm_config(vllm_config):
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
dtype = torch.half dtype = torch.half
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch) per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(mt.a, topk_weights, topk_ids, _ = fused_topk(mt.a,
score, score,
topk, topk,
renormalize=False) renormalize=False)
# Note that we are using the dequantized versions of the tensors. # Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences. # Using a, w1 and w2 directly results in minor output differences.
...@@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP( ...@@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP(
ep_size: int, ep_size: int,
): ):
current_platform.seed_everything(7) current_platform.seed_everything(7)
with set_current_vllm_config( with set_current_vllm_config(vllm_config):
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel) per_out_channel)
score = torch.randn((m, e), device="cuda", dtype=torch.half) score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids = fused_topk(mt.a, topk_weights, topk_ids, _ = fused_topk(mt.a,
score, score,
topk, topk,
renormalize=False) renormalize=False)
# Note that we are using the dequantized versions of the tensors. # Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences. # Using a, w1 and w2 directly results in minor output differences.
......
...@@ -11,24 +11,32 @@ from transformers import MixtralConfig ...@@ -11,24 +11,32 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe) fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_fp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize, marlin_quantize) awq_marlin_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights) quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import ScalarType, scalar_types
NUM_EXPERTS = [8, 64] NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4] EP_SIZE = [1, 4]
TOP_KS = [2, 6] TOP_KS = [2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048])
...@@ -67,31 +75,33 @@ def test_fused_moe( ...@@ -67,31 +75,33 @@ def test_fused_moe(
else: else:
e_map = None e_map = None
torch_output = torch_moe(a, w1, w2, score, topk, e_map) with set_current_vllm_config(vllm_config):
iterative_output = iterative_moe(a, torch_output = torch_moe(a, w1, w2, score, topk, e_map)
w1, iterative_output = iterative_moe(a,
w2, w1,
score, w2,
topk, score,
global_num_experts=e, topk,
expert_map=e_map, global_num_experts=e,
renormalize=False) expert_map=e_map,
renormalize=False)
# Pad the weight if moe padding is enabled
if padding: # Pad the weight if moe padding is enabled
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] if padding:
torch.cuda.empty_cache() w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] torch.cuda.empty_cache()
torch.cuda.empty_cache() w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
triton_output = fused_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
triton_output = fused_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(iterative_output, torch.testing.assert_close(iterative_output,
torch_output, torch_output,
...@@ -112,7 +122,6 @@ def test_fused_moe( ...@@ -112,7 +122,6 @@ def test_fused_moe(
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
ep_size: int, dtype: torch.dtype, group_size: int, ep_size: int, dtype: torch.dtype, group_size: int,
has_zp: bool, weight_bits: int): has_zp: bool, weight_bits: int):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
...@@ -191,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -191,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else: else:
e_map = None e_map = None
triton_output = fused_moe(a, with set_current_vllm_config(vllm_config):
w1_qweight, triton_output = fused_moe(a,
w2_qweight, w1_qweight,
score, w2_qweight,
topk, score,
renormalize=False, topk,
use_int4_w4a16=weight_bits == 4, renormalize=False,
use_int8_w8a16=weight_bits == 8, use_int4_w4a16=weight_bits == 4,
global_num_experts=e, use_int8_w8a16=weight_bits == 8,
expert_map=e_map, global_num_experts=e,
w1_scale=w1_scales, expert_map=e_map,
w2_scale=w2_scales, w1_scale=w1_scales,
w1_zp=w1_qzeros if has_zp else None, w2_scale=w2_scales,
w2_zp=w2_qzeros if has_zp else None, w1_zp=w1_qzeros if has_zp else None,
block_shape=[0, group_size]) w2_zp=w2_qzeros if has_zp else None,
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
...@@ -221,9 +232,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -221,9 +232,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
"""Make sure our Mixtral MoE implementation agrees with the one from """Make sure our Mixtral MoE implementation agrees with the one from
huggingface.""" huggingface."""
# clear the cache before every test
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter: if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")
# Instantiate our and huggingface's MoE blocks # Instantiate our and huggingface's MoE blocks
config = MixtralConfig() config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
...@@ -285,18 +303,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -285,18 +303,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol=mixtral_moe_tol[dtype]) atol=mixtral_moe_tol[dtype])
@pytest.mark.parametrize("m", [1, 33, 123]) def marlin_moe_generate_valid_test_cases():
@pytest.mark.parametrize("n", [128, 1024]) import itertools
@pytest.mark.parametrize("k", [256, 2048]) m_list = [1, 123, 666]
@pytest.mark.parametrize("e", [4, 12]) n_list = [128, 1024]
@pytest.mark.parametrize("topk", [2, 3]) k_list = [256, 2048]
@pytest.mark.parametrize("ep_size", [1, 4]) e_list = [4, 12]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) topk_list = [2, 3]
@pytest.mark.parametrize("group_size", [-1, 32, 128]) ep_size_list = [1, 4]
@pytest.mark.parametrize("act_order", [True, False]) dtype_list = [torch.half, torch.bfloat16]
@pytest.mark.parametrize("num_bits", [4, 8]) group_size_list = [-1, 16, 32, 128]
@pytest.mark.parametrize("has_zp", [True, False]) act_order_list = [True, False]
@pytest.mark.parametrize("is_k_full", [True, False]) quant_type_list = [
scalar_types.float4_e2m1f,
scalar_types.float8_e4m3fn,
scalar_types.uint4,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
is_k_full_list = [True, False]
all_combinations = itertools.product(m_list, n_list, k_list, e_list,
topk_list, ep_size_list, dtype_list,
group_size_list, act_order_list,
quant_type_list, is_k_full_list)
def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
quant_type, is_k_full):
if quant_type == scalar_types.float8_e4m3fn and \
group_size not in [-1, 128]:
return False
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return False
# Filter act_order
if act_order:
if group_size in (-1, k, n):
return False
if quant_type not in [scalar_types.uint4b8]:
return False
elif not is_k_full:
return False
return True
cases = []
for case in all_combinations:
if is_invalid(*case):
cases.append(case)
return cases
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases())
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe( def test_fused_marlin_moe(
m: int, m: int,
...@@ -308,14 +372,22 @@ def test_fused_marlin_moe( ...@@ -308,14 +372,22 @@ def test_fused_marlin_moe(
dtype: torch.dtype, dtype: torch.dtype,
group_size: int, group_size: int,
act_order: bool, act_order: bool,
num_bits: int, quant_type: ScalarType,
has_zp: bool,
is_k_full: bool, is_k_full: bool,
): ):
current_platform.seed_everything(7) torch.cuda.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
if quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
return
# Filter act_order # Filter act_order
if act_order: if act_order:
if quant_type == scalar_types.float8_e4m3fn:
return
if group_size == -1: if group_size == -1:
return return
if group_size in (k, n): if group_size in (k, n):
...@@ -326,17 +398,14 @@ def test_fused_marlin_moe( ...@@ -326,17 +398,14 @@ def test_fused_marlin_moe(
if not is_k_full: if not is_k_full:
return return
if has_zp: if quant_type == scalar_types.float4_e2m1f and group_size != 16:
# we don't build kernel for int8 with zero return
if num_bits == 8: if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return return
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
else:
quant_type = scalar_types.uint4b8 \
if num_bits == 4 else scalar_types.uint8b128
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
if ep_size > 1: if ep_size > 1:
local_e = e // ep_size local_e = e // ep_size
...@@ -351,12 +420,27 @@ def test_fused_marlin_moe( ...@@ -351,12 +420,27 @@ def test_fused_marlin_moe(
w_ref1_l = [] w_ref1_l = []
qweight1_l = [] qweight1_l = []
scales1_l = [] scales1_l = []
global_scale1_l = []
zeros1_l = [] zeros1_l = []
g_idx1_l = [] g_idx1_l = []
sort_indices1_l = [] sort_indices1_l = []
for i in range(w1.shape[0]): for i in range(w1.shape[0]):
if has_zp: if quant_type == scalar_types.float4_e2m1f:
w_ref1, qweight1, scales1, global_scale1 = \
rand_marlin_weight_fp4_like(w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
elif has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size) w1[i].transpose(1, 0), quant_type, group_size)
...@@ -366,9 +450,9 @@ def test_fused_marlin_moe( ...@@ -366,9 +450,9 @@ def test_fused_marlin_moe(
zeros1_l.append(zeros1) zeros1_l.append(zeros1)
else: else:
test_perm = torch.randperm(k) test_perm = torch.randperm(k)
quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type, w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
group_size, act_order, test_perm) marlin_quantize(w1[i].transpose(1, 0), quant_type,
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res group_size, act_order, test_perm)
w_ref1_l.append(w_ref1.T) w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1) qweight1_l.append(qweight1)
...@@ -379,6 +463,7 @@ def test_fused_marlin_moe( ...@@ -379,6 +463,7 @@ def test_fused_marlin_moe(
w_ref1 = stack_and_dev(w_ref1_l) w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous() qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l) scales1 = stack_and_dev(scales1_l)
global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
...@@ -386,12 +471,27 @@ def test_fused_marlin_moe( ...@@ -386,12 +471,27 @@ def test_fused_marlin_moe(
w_ref2_l = [] w_ref2_l = []
qweight2_l = [] qweight2_l = []
scales2_l = [] scales2_l = []
global_scale2_l = []
zeros2_l = [] zeros2_l = []
g_idx2_l = [] g_idx2_l = []
sort_indices2_l = [] sort_indices2_l = []
for i in range(w2.shape[0]): for i in range(w2.shape[0]):
if has_zp: if quant_type == scalar_types.float4_e2m1f:
w_ref2, qweight2, scales2, global_scale2 = \
rand_marlin_weight_fp4_like(w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
elif has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size) w2[i].transpose(1, 0), quant_type, group_size)
...@@ -401,9 +501,9 @@ def test_fused_marlin_moe( ...@@ -401,9 +501,9 @@ def test_fused_marlin_moe(
zeros2_l.append(zeros2) zeros2_l.append(zeros2)
else: else:
test_perm = torch.randperm(n) test_perm = torch.randperm(n)
quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type, w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
group_size, act_order, test_perm) marlin_quantize(w2[i].transpose(1, 0), quant_type,
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res group_size, act_order, test_perm)
w_ref2_l.append(w_ref2.T) w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2) qweight2_l.append(qweight2)
...@@ -414,15 +514,17 @@ def test_fused_marlin_moe( ...@@ -414,15 +514,17 @@ def test_fused_marlin_moe(
w_ref2 = stack_and_dev(w_ref2_l) w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous() qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l) scales2 = stack_and_dev(scales2_l)
global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe( marlin_output = torch.ops.vllm.fused_marlin_moe(
a, a,
...@@ -435,108 +537,18 @@ def test_fused_marlin_moe( ...@@ -435,108 +537,18 @@ def test_fused_marlin_moe(
topk_ids, topk_ids,
global_num_experts=e, global_num_experts=e,
expert_map=e_map, expert_map=e_map,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1, g_idx1=g_idx1,
g_idx2=g_idx2, g_idx2=g_idx2,
sort_indices1=sort_indices1, sort_indices1=sort_indices1,
sort_indices2=sort_indices2, sort_indices2=sort_indices2,
w1_zeros=zeros1, w1_zeros=zeros1,
w2_zeros=zeros2, w2_zeros=zeros2,
num_bits=num_bits, quant_type_id=quant_type.id,
is_k_full=is_k_full) is_k_full=is_k_full)
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [1, 33, 123])
@pytest.mark.parametrize("n", [128, 1024])
@pytest.mark.parametrize("k", [256, 2048])
@pytest.mark.parametrize("e", [4, 12])
@pytest.mark.parametrize("topk", [2, 3])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
                                    dtype: torch.dtype, group_size: int,
                                    act_order: bool, num_bits: int,
                                    has_zp: bool, is_k_full: bool):
    """Check ``single_marlin_moe`` against the ``torch_moe_single`` reference.

    Random activations ``a`` of shape ``(m, k)`` and per-expert weights ``w``
    of shape ``(e, n, k)`` are generated, every expert's weight is quantized
    (``awq_marlin_quantize`` when ``has_zp`` is set, ``marlin_quantize``
    otherwise), and the kernel output is compared with the reference output
    computed from the de-quantized weights (``atol=2e-2``).

    Parameter combinations rejected by the filter below return early, so
    they are reported as passing rather than skipped.
    """
    # Filter act_order: with act_order we need a real group size that differs
    # from k and n and no zero points; without it only is_k_full is exercised.
    if act_order:
        if group_size == -1 or group_size in (k, n) or has_zp:
            return
    elif not is_k_full:
        return

    # Zero-point quantization uses the plain unsigned types, otherwise the
    # biased variants are used.
    is_4bit = num_bits == 4
    if has_zp:
        quant_type = scalar_types.uint4 if is_4bit else scalar_types.uint8
    else:
        quant_type = (scalar_types.uint4b8
                      if is_4bit else scalar_types.uint8b128)

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    per_expert_w_ref = []
    per_expert_qweight = []
    per_expert_scales = []
    per_expert_zeros = []
    per_expert_g_idx = []
    per_expert_sort_indices = []

    for expert in range(w.shape[0]):
        if has_zp:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                w[expert].transpose(1, 0), quant_type, group_size)
            per_expert_zeros.append(zeros)
        else:
            test_perm = torch.randperm(k)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                w[expert].transpose(1, 0), quant_type, group_size, act_order,
                test_perm)
            per_expert_g_idx.append(g_idx)
            per_expert_sort_indices.append(sort_indices)

        # Both quantizers produce these three results.
        per_expert_w_ref.append(w_ref.T)
        per_expert_qweight.append(qweight)
        per_expert_scales.append(scales)

    w_ref = stack_and_dev(per_expert_w_ref)
    qweight = stack_and_dev(per_expert_qweight).contiguous()
    scales = stack_and_dev(per_expert_scales)
    # Lists stay empty for the quantization path that does not produce them.
    g_idx = stack_and_dev(per_expert_g_idx) if per_expert_g_idx else None
    zeros = stack_and_dev(per_expert_zeros) if per_expert_zeros else None
    sort_indices = (stack_and_dev(per_expert_sort_indices)
                    if per_expert_sort_indices else None)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    marlin_output = torch.ops.vllm.single_marlin_moe(
        a,
        qweight,
        scales,
        score,
        topk,
        renormalize=False,
        g_idx=g_idx,
        sort_indices=sort_indices,
        w_zeros=zeros,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )
    torch_output = torch_moe_single(a, w_ref, score, topk)

    torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
def test_moe_align_block_size_opcheck(): def test_moe_align_block_size_opcheck():
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE permute/unpermute kernel
Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
"""
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
from vllm.platforms import current_platform
# Global expert counts used to parametrize `n_expert` in the test below.
NUM_EXPERTS = [16, 64]
# Number of experts selected per token (`topk`) used to parametrize the test.
TOP_KS = [2, 4, 6, 8]
# Expert-parallel sizes used to parametrize `ep_size`; any value other than 1
# makes the test build an expert map via `determine_expert_map`.
EP_SIZE = [1, 4, 16]
# Seed once at import time (the test re-seeds itself as well).
# NOTE(review): presumably seeds Python, NumPy and torch RNGs — confirm
# against `current_platform.seed_everything`.
current_platform.seed_everything(0)
def torch_permute(hidden_states: torch.Tensor,
                  topk_ids: torch.Tensor,
                  token_expert_indices: torch.Tensor,
                  topk: int,
                  n_expert: int,
                  n_local_expert: int,
                  start_expert: int,
                  expert_map: Optional[torch.Tensor] = None,
                  align_block_size: Optional[int] = None,
                  fill_invalid_expert: int = -1) -> list:
    """Pure-PyTorch reference for the MoE permute step.

    Rows of ``hidden_states`` are regrouped so that all (token, expert)
    pairs routed to the same local expert are contiguous, ordered by local
    expert id (stable with respect to the flattened ``topk_ids`` order).

    All tensors created here live on ``hidden_states.device``.

    Args:
        hidden_states: ``(n_token, n_hidden)`` activations.
        topk_ids: ``(n_token, topk)`` selected global expert ids.
        token_expert_indices: ``(n_token, topk)`` source-row id of every
            (token, expert) pair; ``id % n_token`` is used as the token index.
        topk: number of experts selected per token.
        n_expert: global number of experts.
        n_local_expert: number of experts handled locally.
        start_expert: global id of the first local expert; only used when
            ``expert_map`` is given.
        expert_map: optional tensor indexed by global expert id whose
            entries are ``-1`` for non-local experts.
        align_block_size: when given, every expert's row range is padded up
            to a multiple of this value.
        fill_invalid_expert: value stored in ``m_indices`` for rows that do
            not belong to any local expert.

    Returns:
        A five-element list:

        0. ``permuted_hidden_states`` - the regrouped rows.  In aligned mode
           padding rows are left uninitialised.
        1. ``expert_first_token_offset`` - int64, ``(n_local_expert + 1,)``;
           start offset of every local expert's rows (padded offsets in
           aligned mode).
        2. ``src_row_id2dst_row_id_map`` - int32, ``(n_token, topk)``;
           destination row of every source row.
        3. ``m_indices`` - int32; local expert id of every permuted row.
        4. ``valid_row_idx`` - Python ``list[int]`` of permuted rows that hold
           real (non-padding, locally owned) data.
    """
    device = hidden_states.device
    n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]

    if expert_map is not None:
        # Local experts are renumbered from 0; non-local ones are pushed past
        # every valid id so that they sort to the end.
        is_local_expert = expert_map[topk_ids] != -1
        topk_ids = torch.where(is_local_expert, topk_ids - start_expert,
                               topk_ids + n_expert)

    sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
                                                 stable=True)
    dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]

    # Tokens per local expert -> exclusive prefix sum.  Ids that are
    # >= n_local_expert (non-local experts) are dropped by the slice.
    tokens_per_expert = torch.bincount(
        sorted_topk_ids, minlength=n_local_expert)[:n_local_expert]
    expert_first_token_offset = torch.zeros(n_local_expert + 1,
                                            dtype=torch.int64,
                                            device=device)
    expert_first_token_offset[1:] = torch.cumsum(tokens_per_expert, dim=0)

    # Inverse permutation: position of every source row in the sorted order.
    _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
    valid_row_idx: list = []

    if align_block_size is None:
        permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map %
                                               n_token, ...]
        permuted_row_size = permuted_hidden_states.shape[0]
        m_indices = torch.full((permuted_row_size, ),
                               fill_invalid_expert,
                               device=device,
                               dtype=torch.int32)
        for expert in range(n_local_expert):
            first_token_offset = int(expert_first_token_offset[expert])
            last_token_offset = int(expert_first_token_offset[expert + 1])
            m_indices[first_token_offset:last_token_offset] = expert
        src_row_id2dst_row_id_map = torch.arange(
            0, n_token * topk, device=device,
            dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
        # Everything before the end of the last local expert is real data.
        valid_row_idx.extend(range(int(expert_first_token_offset[-1])))
        return [
            permuted_hidden_states, expert_first_token_offset,
            src_row_id2dst_row_id_map, m_indices, valid_row_idx
        ]

    # Aligned mode: upper bound of rows when every expert needs padding.
    permuted_row_size = (topk * n_token + n_expert *
                         (align_block_size - 1) + align_block_size -
                         1) // align_block_size * align_block_size
    permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
                                         device=device,
                                         dtype=hidden_states.dtype)
    align_expert_first_token_offset = torch.zeros_like(
        expert_first_token_offset)
    m_indices = torch.full((permuted_row_size, ),
                           fill_invalid_expert,
                           device=device,
                           dtype=torch.int32)

    # Fill permuted_hidden_states, valid_row_idx and the aligned offsets.
    for expert in range(n_local_expert):
        first_token_offset = int(expert_first_token_offset[expert])
        n_token_in_expert = (int(expert_first_token_offset[expert + 1]) -
                             first_token_offset)
        align_first_token_offset = int(
            align_expert_first_token_offset[expert])
        align_last_token_offset = align_first_token_offset + (
            n_token_in_expert + align_block_size -
            1) // align_block_size * align_block_size
        align_expert_first_token_offset[expert + 1] = align_last_token_offset

        src_rows = dst_row_id2src_row_id_map[
            first_token_offset:first_token_offset +
            n_token_in_expert] % n_token
        # Real tokens go to the front of the expert's aligned range.
        permuted_hidden_states[
            align_first_token_offset:align_first_token_offset +
            n_token_in_expert, ...] = hidden_states[src_rows, ...]
        # The whole aligned range (padding included) belongs to this expert.
        m_indices[align_first_token_offset:align_last_token_offset] = expert
        valid_row_idx.extend(
            range(align_first_token_offset,
                  align_first_token_offset + n_token_in_expert))

    # Destination row of every sorted position: shift the position inside
    # its expert from the packed offset to the aligned offset.  Rows of
    # non-local experts all point at the end of the last aligned range.
    sorted_pos = torch.arange(n_token * topk, device=device)
    in_local_expert = sorted_topk_ids < n_local_expert
    local_eid = sorted_topk_ids.clamp(max=n_local_expert - 1).to(torch.int64)
    aligned_dst = (align_expert_first_token_offset[local_eid] + sorted_pos -
                   expert_first_token_offset[local_eid])
    aligned_dst = torch.where(in_local_expert, aligned_dst,
                              align_expert_first_token_offset[-1])
    align_src_row_id2dst_row_id = aligned_dst.to(
        torch.int32)[src2dst_idx].reshape((n_token, topk))
    return [
        permuted_hidden_states, align_expert_first_token_offset,
        align_src_row_id2dst_row_id, m_indices, valid_row_idx
    ]
def torch_unpermute(permuted_hidden_states: torch.Tensor,
                    topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                    token_expert_indices: torch.Tensor,
                    src_row_id2dst_row_id_map: torch.Tensor,
                    valid_row_idx: torch.Tensor, topk: int,
                    n_expert: int) -> torch.Tensor:
    """Pure-PyTorch reference for the MoE unpermute step.

    Gathers, for every token, the ``topk`` permuted rows it was routed to,
    scales each by its routing weight and sums them.  Rows not listed in
    ``valid_row_idx`` contribute zero.

    The input ``permuted_hidden_states`` is not modified, and the mask is
    created on its device.

    Args:
        permuted_hidden_states: permuted rows, first dimension is the
            permuted row index.
        topk_weights: ``(n_token, topk)`` routing weights.
        topk_ids: unused; kept for signature compatibility.
        token_expert_indices: ``(n_token, topk)`` source-row ids.
        src_row_id2dst_row_id_map: map from source-row id to permuted row
            (indexed after flattening).
        valid_row_idx: indices of permuted rows holding real data; a Python
            list of ints or an integer tensor.
        topk: unused; kept for signature compatibility.
        n_expert: unused; kept for signature compatibility.

    Returns:
        ``(n_token, ...)`` tensor with the dtype of ``permuted_hidden_states``.
    """
    device = permuted_hidden_states.device

    # Zero out invalid rows on a copy so the caller's tensor stays intact.
    valid_mask = torch.zeros(permuted_hidden_states.shape[0],
                             dtype=torch.bool,
                             device=device)
    valid_mask[torch.as_tensor(valid_row_idx, dtype=torch.long,
                               device=device)] = True
    masked_hidden_states = permuted_hidden_states.clone()
    masked_hidden_states[~valid_mask] = 0

    # Permuted row of every (token, k) pair.
    idx = src_row_id2dst_row_id_map.flatten()[
        token_expert_indices.flatten()].reshape(token_expert_indices.shape)
    weighted = masked_hidden_states[idx, ...] * topk_weights[..., None]
    return weighted.sum(dim=1).to(permuted_hidden_states.dtype)
@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168])
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
                               n_expert: int, ep_size: int, dtype: torch.dtype,
                               align_block_size: Optional[int]):
    """Compare ``moe_permute`` / ``moe_unpermute`` with the torch references.

    The permute outputs (expert offsets, row map, ``m_indices`` and the valid
    permuted rows) must match ``torch_permute`` exactly; the unpermute output
    must match ``torch_unpermute`` within ``atol=2e-2``.
    """
    # Seed before *any* random draw so that the chosen EP rank does not
    # depend on RNG state left behind by whatever ran earlier.
    # NOTE(review): assumes seed_everything also seeds NumPy - confirm.
    current_platform.seed_everything(0)

    fill_invalid_expert = 0
    ep_rank = np.random.randint(0, ep_size)
    expert_map = None
    n_local_expert = n_expert
    if ep_size != 1:
        n_local_expert, expert_map = determine_expert_map(
            ep_size, ep_rank, n_expert)
        expert_map = expert_map.cuda()
    start_expert = n_local_expert * ep_rank

    hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
    gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        hidden_states, gating_output, topk, False)

    gold0, gold1, gold2, gold3, valid_row_idx = torch_permute(
        hidden_states,
        topk_ids,
        token_expert_indices,
        topk,
        n_expert,
        n_local_expert,
        start_expert,
        expert_map=expert_map,
        align_block_size=align_block_size,
        fill_invalid_expert=fill_invalid_expert)

    result0, result1, result2, result3 = moe_permute(
        hidden_states, topk_weights, topk_ids, token_expert_indices, topk,
        n_expert, n_local_expert, expert_map, align_block_size,
        fill_invalid_expert)

    # check expert_first_token_offset
    torch.testing.assert_close(gold1, result1, atol=0, rtol=0)
    # check src_row_id2dst_row_id_map
    torch.testing.assert_close(gold2, result2, atol=0, rtol=0)
    # check m_indices
    torch.testing.assert_close(gold3, result3, atol=0, rtol=0)
    # check permuted_hidden_states, only valid tokens (padding rows and rows
    # of non-local experts are excluded from the comparison)
    torch.testing.assert_close(gold0[valid_row_idx],
                               result0[valid_row_idx],
                               atol=0,
                               rtol=0)

    # add a random tensor to simulate group gemm
    result0 = 0.5 * result0 + torch.randn_like(result0)

    # The kernel must run before the reference is called on the same tensor.
    result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1,
                            topk, n_expert, n_local_expert)
    gold4 = torch_unpermute(result0, topk_weights, topk_ids,
                            token_expert_indices, result2, valid_row_idx, topk,
                            n_local_expert)

    # check unpermuted hidden
    torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment