Commit e661d594 authored by zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
"""
Repeat of tests in test_completion.py with the non-mp backend.
"""
# imports for guided decoding tests
import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import List
import jsonschema
import openai # use the official client for correctness check
import pytest
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
# Base model served by the test server.
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# Hugging Face repo ids fetched with snapshot_download by the fixtures below.
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# Added to the expected prompt token counts when a prompt adapter is used.
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
@pytest.fixture(scope="module")
def zephyr_lora_files():
    """Download the LoRA adapter snapshot and return its local path."""
    lora_dir = snapshot_download(repo_id=LORA_NAME)
    return lora_dir
@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
    """Yield a temporary copy of the LoRA adapter with an extended tokenizer.

    The adapter files are copied into a temporary directory, the base model
    tokenizer is extended with three extra tokens and saved next to the
    adapter weights. The directory is removed once the module's tests finish.

    Args:
        zephyr_lora_files: Path of the downloaded LoRA adapter snapshot.

    Yields:
        Path of the temporary adapter directory.
    """
    # The context manager removes the directory even if the copy or the
    # tokenizer setup raises before the fixture gets to yield.
    with TemporaryDirectory() as tmp_dir:
        tmp_model_dir = f"{tmp_dir}/zephyr"
        shutil.copytree(zephyr_lora_files, tmp_model_dir)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Copy tokenizer to adapter and add some unique tokens
        # 32000, 32001, 32002
        added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
                                     special_tokens=True)
        assert added == 3
        tokenizer.save_pretrained(tmp_model_dir)
        yield tmp_model_dir
@pytest.fixture(scope="module")
def zephyr_pa_files():
    """Download the prompt adapter snapshot and return its local path."""
    pa_dir = snapshot_download(repo_id=PA_NAME)
    return pa_dir
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
                        zephyr_pa_files):
    """Build the command-line arguments used to launch the test server.

    The result enables LoRA with two adapters, prompt adapters with two
    entries, and disables frontend multiprocessing.
    """
    # use half precision for speed and memory savings in CI environment
    engine_flags = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]
    # lora config
    lora_flags = [
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora={zephyr_lora_files}",
        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
    ]
    # pa config
    prompt_adapter_flags = [
        "--enable-prompt-adapter",
        "--prompt-adapters",
        f"zephyr-pa={zephyr_pa_files}",
        f"zephyr-pa2={zephyr_pa_files}",
        "--max-prompt-adapters",
        "2",
        "--max-prompt-adapter-token",
        "128",
    ]
    frontend_flags = ["--disable-frontend-multiprocessing"]
    return (engine_flags + lora_flags + prompt_adapter_flags +
            frontend_flags)
@pytest.fixture(scope="module")
def server(default_server_args):
    """Start a RemoteOpenAIServer for the module and yield it."""
    launcher = RemoteOpenAIServer(MODEL_NAME, default_server_args)
    with launcher as running_server:
        yield running_server
@pytest.fixture(scope="module")
def client(server):
    """Return an async client obtained from the running test server."""
    async_client = server.get_async_client()
    return async_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras, then test prompt adapters
    "model_name,num_virtual_tokens",
    [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
     ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
     ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
                                 num_virtual_tokens: int):
    """Request one completion from a text prompt, then from token IDs.

    For the text prompt the reported usage must include the prompt adapter's
    virtual tokens in addition to the six prompt tokens.
    """
    text_result = await client.completions.create(
        model=model_name,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
    )
    assert text_result.id is not None

    choices = text_result.choices
    assert choices is not None and len(choices) == 1

    first_choice = choices[0]
    assert len(first_choice.text) >= 5
    assert first_choice.finish_reason == "length"

    expected_usage = openai.types.CompletionUsage(
        completion_tokens=5,
        prompt_tokens=6 + num_virtual_tokens,
        total_tokens=11 + num_virtual_tokens)
    assert text_result.usage == expected_usage

    # test using token IDs
    token_result = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(token_result.choices[0].text) >= 1
@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
    """Echoed prompt for the token-extended LoRA must show the added tokens."""
    # test using token IDs
    prompt_ids = [0, 0, 32000, 32001, 32002]
    response = await client.completions.create(
        model="zephyr-lora2",
        prompt=prompt_ids,
        echo=True,
        max_tokens=5,
        temperature=0.0,
    )
    echoed = response.choices[0].text
    # Added tokens should appear in tokenized prompt
    assert echoed.startswith("<unk><unk>vllm1vllm2vllm3")
@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
    """Echoed prompt for the base model must not show the LoRA-only tokens."""
    # test using token IDs
    prompt_ids = [0, 0, 32000, 32001, 32002]
    response = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt_ids,
        echo=True,
        max_tokens=5,
        temperature=0.0,
    )
    echoed = response.choices[0].text
    # Added tokens should not appear in tokenized prompt
    assert "vllm" not in echoed
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras, then test prompt adapters
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
    """With logprobs=None the returned choice carries no logprobs."""
    # test using token IDs
    response = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=None,
    )
    assert response.choices[0].logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora and 1 pa hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
    """With logprobs=0 each position reports exactly one top logprob."""
    # test using token IDs
    response = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=0,
    )
    logprob_info = response.choices[0].logprobs
    assert logprob_info is not None
    assert logprob_info.token_logprobs is not None
    assert logprob_info.top_logprobs is not None
    assert len(logprob_info.top_logprobs[0]) == 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
    """With logprobs=5 the first position reports five or six top logprobs."""
    # test using token IDs
    response = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=5,
    )
    logprob_info = response.choices[0].logprobs
    assert logprob_info is not None
    assert logprob_info.token_logprobs is not None
    assert logprob_info.top_logprobs is not None
    first_position_count = len(logprob_info.top_logprobs[0])
    assert 5 <= first_position_count <= 6
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                            model_name: str):
    """Requests asking for too many logprobs are rejected.

    Both the non-streaming and the streaming request must raise, and the
    server must still answer a regular request afterwards.
    """
    with pytest.raises(
        (openai.BadRequestError, openai.APIError)):  # test using token IDs
        await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            # vLLM has higher default max_logprobs (20 instead of 5) to support
            # both Completion API and Chat Completion API
            logprobs=21,
        )

    with pytest.raises(
        (openai.BadRequestError, openai.APIError)):  # test using token IDs
        stream = await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            # vLLM has higher default max_logprobs (20 instead of 5) to support
            # both Completion API and Chat Completion API
            logprobs=30,
            stream=True,
        )
        # Drain the stream so an error raised while iterating is observed.
        async for _ in stream:
            pass

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    # Same request as the token-ID case in test_single_completion; require
    # actual text, since a ">= 0" length check can never fail.
    assert len(completion.choices[0].text) >= 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
                                    model_name: str):
    """Streamed chunks must concatenate to the non-streamed completion."""
    prompt = "What is an LLM?"
    request_kwargs = dict(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
    )

    reference = await client.completions.create(**request_kwargs)
    single_output = reference.choices[0].text

    stream = await client.completions.create(**request_kwargs, stream=True)
    pieces: List[str] = []
    finished_chunks = 0
    async for chunk in stream:
        current = chunk.choices[0]
        pieces.append(current.text)
        if current.finish_reason is not None:
            finished_chunks += 1

    # finish reason should only return in last block
    assert finished_chunks == 1
    # "chunk" still refers to the last chunk received from the stream
    last_choice = chunk.choices[0]
    assert last_choice.finish_reason == "length"
    assert last_choice.text
    assert "".join(pieces) == single_output
def _assert_final_usage_chunk(final_chunk) -> None:
    """Check the usage-only chunk that follows the last content chunk."""
    usage = final_chunk.usage
    assert usage is not None
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens == (usage.prompt_tokens +
                                  usage.completion_tokens)
    assert final_chunk.choices == []


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                         model_name: str):
    """Exercise every include_usage / continuous_usage_stats combination.

    Also checks that passing stream_options with stream=False is rejected.
    """
    prompt = "What is the capital of France?"

    async def open_stream(include_usage: bool, continuous_usage_stats: bool):
        # All streaming cases share the same request apart from the options.
        return await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=True,
            stream_options={
                "include_usage": include_usage,
                "continuous_usage_stats": continuous_usage_stats,
            })

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": False}
    stream = await open_stream(False, False)
    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": True}
    stream = await open_stream(False, True)
    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": False}
    stream = await open_stream(True, False)
    async for chunk in stream:
        is_last_content_chunk = chunk.choices[0].finish_reason is not None
        assert chunk.usage is None
        if is_last_content_chunk:
            # The usage-only chunk is pulled manually from the stream.
            _assert_final_usage_chunk(await stream.__anext__())

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": True}
    stream = await open_stream(True, True)
    async for chunk in stream:
        running_usage = chunk.usage
        assert running_usage is not None
        assert running_usage.prompt_tokens > 0
        assert running_usage.completion_tokens > 0
        assert running_usage.total_tokens == (
            running_usage.prompt_tokens + running_usage.completion_tokens)
        if chunk.choices[0].finish_reason is not None:
            _assert_final_usage_chunk(await stream.__anext__())

    # Test stream=False with each of these stream_options, in order:
    #     {"include_usage": None}
    #     {"include_usage": True}
    #     {"continuous_usage_stats": None}
    #     {"continuous_usage_stats": True}
    rejected_options = (
        {"include_usage": None},
        {"include_usage": True},
        {"continuous_usage_stats": None},
        {"continuous_usage_stats": True},
    )
    for stream_options in rejected_options:
        with pytest.raises(BadRequestError):
            await client.completions.create(model=model_name,
                                            prompt=prompt,
                                            max_tokens=5,
                                            temperature=0.0,
                                            stream=False,
                                            stream_options=stream_options)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
    """Send two identical prompts per request: plain, n=2 and streaming."""
    text_batch = ["Hello, my name is"] * 2
    token_id_batch = [[0, 0, 0, 0, 0]] * 2

    # test both text and token IDs
    for prompts in (text_batch, token_id_batch):
        # test simple list
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            max_tokens=5,
            temperature=0.0,
        )
        assert len(batch.choices) == 2
        assert batch.choices[0].text == batch.choices[1].text

        # test n = 2
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            n=2,
            max_tokens=5,
            temperature=0.0,
            extra_body=dict(
                # NOTE: this has to be true for n > 1 in vLLM, but not necessary
                # for official client.
                use_beam_search=True),
        )
        assert len(batch.choices) == 4
        beam_texts = [choice.text for choice in batch.choices]
        assert beam_texts[0] != beam_texts[1], \
            "beam search should be different"
        assert beam_texts[0] == beam_texts[2], \
            "two copies of the same prompt should be the same"
        assert beam_texts[1] == beam_texts[3], \
            "two copies of the same prompt should be the same"

        # test streaming
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            max_tokens=5,
            temperature=0.0,
            stream=True,
        )
        texts = ["", ""]
        async for chunk in batch:
            assert len(chunk.choices) == 1
            piece = chunk.choices[0]
            # Each chunk is appended to the text of the prompt it indexes.
            texts[piece.index] += piece.text
        assert texts[0] == texts[1]
@pytest.mark.asyncio
async def test_logits_bias(client: openai.AsyncOpenAI):
    """Check that logit_bias can force a token and can ban tokens."""
    prompt = "Hello, my name is"
    max_tokens = 5
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    def encode(text):
        # Token IDs of the text without any special tokens.
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    # Test exclusive selection
    token_id = 1000
    forced = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token_id): 100},
        seed=42,
    )
    forced_text = forced.choices[0].text
    assert len(forced_text) >= 5
    response_tokens = encode(forced_text)
    expected_tokens = encode(tokenizer.decode([token_id] * 5))
    assert all(response == expected
               for response, expected in zip(response_tokens,
                                             expected_tokens))

    # Test ban
    unbiased = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
    )
    first_response = unbiased.choices[0].text
    response_tokens = encode(first_response)

    banned_bias = {str(token): -100 for token in response_tokens}
    biased = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias=banned_bias,
    )
    assert first_response != biased.choices[0].text
@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
    """The single generated token must be one of the allowed token IDs."""
    prompt = "Hello, my name is"
    max_tokens = 1
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    allowed_ids = [21555, 21557, 21558]
    response = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        seed=42,
        extra_body=dict(allowed_token_ids=allowed_ids),
        logprobs=1,
    )
    generated_tokens = response.choices[0].logprobs.tokens
    assert len(generated_tokens) == 1
    generated_ids = tokenizer.convert_tokens_to_ids(generated_tokens)
    assert generated_ids[0] in allowed_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(client: openai.AsyncOpenAI,
                                      guided_decoding_backend: str,
                                      sample_json_schema):
    """Every guided_json sample must parse and validate against the schema."""
    guided_options = dict(guided_json=sample_json_schema,
                          guided_decoding_backend=guided_decoding_backend)
    response = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}",
        n=3,
        temperature=1.0,
        max_tokens=500,
        extra_body=guided_options)

    assert response.id is not None
    assert len(response.choices) == 3
    for choice in response.choices:
        parsed = json.loads(choice.text)
        jsonschema.validate(instance=parsed, schema=sample_json_schema)
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str,
                                       sample_regex):
    """Every guided_regex sample must fully match the requested pattern."""
    guided_options = dict(guided_regex=sample_regex,
                          guided_decoding_backend=guided_decoding_backend)
    response = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
        n=3,
        temperature=1.0,
        max_tokens=20,
        extra_body=guided_options)

    assert response.id is not None
    assert len(response.choices) == 3
    for choice in response.choices:
        assert re.fullmatch(sample_regex, choice.text) is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str,
                                        sample_guided_choice):
    """Every guided_choice sample must be one of the offered choices."""
    guided_options = dict(guided_choice=sample_guided_choice,
                          guided_decoding_backend=guided_decoding_backend)
    response = await client.completions.create(
        model=MODEL_NAME,
        prompt="The best language for type-safe systems programming is ",
        n=2,
        temperature=1.0,
        max_tokens=10,
        extra_body=guided_options)

    assert response.id is not None
    assert len(response.choices) == 2
    for choice in response.choices:
        assert choice.text in sample_guided_choice
@pytest.mark.asyncio
async def test_guided_grammar(client: openai.AsyncOpenAI,
                              sample_sql_statements):
    """Grammar-guided output must parse and equal the expected SQL."""
    # use Lark to parse the output, and make sure it's a valid parse tree
    from lark import Lark

    response = await client.completions.create(
        model=MODEL_NAME,
        prompt=("Generate a sql state that select col_1 from "
                "table_1 where it is equals to 1"),
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_grammar=sample_sql_statements))
    content = response.choices[0].text

    sql_parser = Lark(sample_sql_statements)
    sql_parser.parse(content)

    # remove spaces for comparison b/c we removed them in the grammar
    expected_sql = "SELECT col_1 from table_1 where col_1 = 1"
    ground_truth = expected_sql.replace(" ", "")
    assert content.strip() == ground_truth
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
                                       model_name: str, logprobs_arg: int):
    """With echo=True the prompt is returned together with its logprobs.

    The response text must start with the prompt, the first position has no
    logprob, and every later position reports between max(logprobs_arg, 1)
    and logprobs_arg + 1 top logprobs.
    """
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    # test using text and token IDs
    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
        completion = await client.completions.create(model=model_name,
                                                     prompt=prompt,
                                                     max_tokens=5,
                                                     temperature=0.0,
                                                     echo=True,
                                                     logprobs=logprobs_arg)

        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                             list) else prompt
        # Escape the prompt so it is matched literally even if it ever
        # contains regex metacharacters.
        assert re.search(r"^" + re.escape(prompt_text),
                         completion.choices[0].text)
        logprobs = completion.choices[0].logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) > 5
        assert (len(logprobs.token_logprobs) > 5
                and logprobs.token_logprobs[0] is None)
        assert (len(logprobs.top_logprobs) > 5
                and logprobs.top_logprobs[0] is None)
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) > 5
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
                                          guided_decoding_backend: str,
                                          sample_json_schema, sample_regex):
    """Invalid guided decoding arguments must raise BadRequestError.

    Covers a non-schema guided_json value and a request that passes both
    guided_regex and guided_json.
    """
    wrong_type_options = dict(guided_json=42,
                              guided_decoding_backend=guided_decoding_backend)
    with pytest.raises(openai.BadRequestError):
        await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example JSON that fits this schema: 42",
            extra_body=wrong_type_options)

    conflicting_options = dict(guided_regex=sample_regex,
                               guided_json=sample_json_schema)
    with pytest.raises(openai.BadRequestError):
        await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example string that fits this regex",
            extra_body=conflicting_options)
...@@ -18,7 +18,6 @@ def embedding_server(): ...@@ -18,7 +18,6 @@ def embedding_server():
"--enforce-eager", "--enforce-eager",
"--max-model-len", "--max-model-len",
"8192", "8192",
"--enforce-eager",
] ]
with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
......
...@@ -36,10 +36,12 @@ def test_oot_registration_for_api_server(): ...@@ -36,10 +36,12 @@ def test_oot_registration_for_api_server():
ctx = torch.multiprocessing.get_context() ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, )) server = ctx.Process(target=server_function, args=(port, ))
server.start() server.start()
MAX_SERVER_START_WAIT_S = 60
client = OpenAI( client = OpenAI(
base_url=f"http://localhost:{port}/v1", base_url=f"http://localhost:{port}/v1",
api_key="token-abc123", api_key="token-abc123",
) )
now = time.time()
while True: while True:
try: try:
completion = client.chat.completions.create( completion = client.chat.completions.create(
...@@ -57,6 +59,8 @@ def test_oot_registration_for_api_server(): ...@@ -57,6 +59,8 @@ def test_oot_registration_for_api_server():
except OpenAIError as e: except OpenAIError as e:
if "Connection error" in str(e): if "Connection error" in str(e):
time.sleep(3) time.sleep(3)
if time.time() - now > MAX_SERVER_START_WAIT_S:
raise RuntimeError("Server did not start in time") from e
else: else:
raise e raise e
server.kill() server.kill()
......
# Separate these tests out from test_completion and test_chat, because they
# require launching a second server with a different flag. Running both servers
# at the same time on a single node will OOM.
import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
from .test_completion import zephyr_pa_files # noqa: F401
from .test_completion import MODEL_NAME
@pytest.fixture(scope="module")
def server_with_return_tokens_as_token_ids_flag(
        default_server_args):  # noqa: F811
    """Start a server with --return-tokens-as-token-ids and yield it."""
    launch_args = [*default_server_args, "--return-tokens-as-token-ids"]
    with RemoteOpenAIServer(MODEL_NAME, launch_args) as flagged_server:
        yield flagged_server
@pytest.mark.asyncio
async def test_completion_return_tokens_as_token_ids_completion(
        server_with_return_tokens_as_token_ids_flag):
    """Tokens returned as "token_id:<id>" must decode to the response text."""
    client = server_with_return_tokens_as_token_ids_flag.get_async_client()

    response = await client.completions.create(
        model=MODEL_NAME,
        # Include Unicode characters to test for dividing a single
        # character across multiple tokens: 🎉 is [28705, 31862] for the
        # Zephyr tokenizer
        prompt="Say 'Hello, world! 🎉'",
        echo=True,
        temperature=0,
        max_tokens=10,
        logprobs=1)

    first_choice = response.choices[0]
    text = first_choice.text
    token_strs = first_choice.logprobs.tokens
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Check that the token representations are consistent between raw tokens
    # and top_logprobs
    # Slice off the first one, because there's no scoring associated with BOS
    scored_positions = first_choice.logprobs.top_logprobs[1:]
    top_logprob_keys = []
    for logprob_by_tokens in scored_positions:
        # The first key of each mapping is compared with the raw token.
        top_logprob_keys.append(next(iter(logprob_by_tokens)))
    assert token_strs[1:] == top_logprob_keys

    # Check that decoding the tokens gives the expected text
    token_ids = [
        int(token_str.removeprefix("token_id:")) for token_str in token_strs
    ]
    decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
    assert text == decoded
@pytest.mark.asyncio
async def test_chat_return_tokens_as_token_ids_completion(
        server_with_return_tokens_as_token_ids_flag):
    """Chat logprob tokens given as "token_id:<id>" must decode to the reply."""
    client = server_with_return_tokens_as_token_ids_flag.get_async_client()

    # Include Unicode characters to test for dividing a single
    # character across multiple tokens: 🎉 is [28705, 31862] for the
    # Zephyr tokenizer
    conversation = [{
        "role": "system",
        "content": "You like to respond in only emojis, like 🎉"
    }, {
        "role": "user",
        "content": "Please write some emojis: 🐱🐶🎉"
    }]
    response = await client.chat.completions.create(model=MODEL_NAME,
                                                    messages=conversation,
                                                    temperature=0,
                                                    max_tokens=8,
                                                    logprobs=True)

    reply = response.choices[0]
    text = reply.message.content
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    token_ids = [
        int(entry.token.removeprefix("token_id:"))
        for entry in reply.logprobs.content
    ]
    assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
import asyncio import asyncio
from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from unittest.mock import MagicMock
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}" CHAT_TEMPLATE = "Dummy chat template for testing {}"
...@@ -42,3 +47,37 @@ async def _async_serving_chat_init(): ...@@ -42,3 +47,37 @@ async def _async_serving_chat_init():
def test_async_serving_chat_init(): def test_async_serving_chat_init():
serving_completion = asyncio.run(_async_serving_chat_init()) serving_completion = asyncio.run(_async_serving_chat_init())
assert serving_completion.chat_template == CHAT_TEMPLATE assert serving_completion.chat_template == CHAT_TEMPLATE
def test_serving_chat_should_set_correct_max_tokens():
    """Check the max_tokens value handed to the engine's generate call.

    Without an explicit max_tokens the sampling params carry 93; once the
    request sets max_tokens to 10, that value is passed through.
    """
    mock_engine = MagicMock(spec=AsyncLLMEngine)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     served_model_names=[MODEL_NAME],
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    user_message = {"role": "user", "content": "what is 1+1?"}
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[user_message],
        guided_decoding_backend="outlines",
    )

    def submitted_max_tokens():
        # Any exception raised by the request is deliberately ignored; only
        # the arguments of the last generate call are inspected.
        with suppress(Exception):
            asyncio.run(serving_chat.create_chat_completion(req))
        # AsyncLLMEngine.generate(inputs, sampling_params, ...)
        sampling_params = mock_engine.generate.call_args.args[1]
        return sampling_params.max_tokens

    assert submitted_max_tokens() == 93

    req.max_tokens = 10
    assert submitted_max_tokens() == 10
...@@ -29,7 +29,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing ...@@ -29,7 +29,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128 # FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256 HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
] if not is_hip() else [64, 80, 96, 112, 128] ] if not is_hip() else [64, 80, 96, 112, 128]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
...@@ -135,6 +135,8 @@ def test_paged_attention( ...@@ -135,6 +135,8 @@ def test_paged_attention(
seed: int, seed: int,
device: str, device: str,
) -> None: ) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
......
...@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float] ...@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256] HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32] BLOCK_SIZES = [8, 16, 32]
# Arbitrary values for testing # Arbitrary values for testing
...@@ -53,6 +53,8 @@ def test_copy_blocks( ...@@ -53,6 +53,8 @@ def test_copy_blocks(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
) -> None: ) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -125,6 +127,8 @@ def test_reshape_and_cache( ...@@ -125,6 +127,8 @@ def test_reshape_and_cache(
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
) -> None: ) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -216,8 +220,6 @@ def test_reshape_and_cache_flash( ...@@ -216,8 +220,6 @@ def test_reshape_and_cache_flash(
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
) -> None: ) -> None:
if kv_cache_dtype == "fp8":
pytest.skip()
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
...@@ -249,15 +251,33 @@ def test_reshape_and_cache_flash( ...@@ -249,15 +251,33 @@ def test_reshape_and_cache_flash(
dtype, dtype,
device=device, device=device,
) )
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0].contiguous(
), value_caches[0].contiguous()
del key_caches
del value_caches
# Clone the KV caches. # Clone the KV caches.
cloned_key_cache = key_cache.clone() if kv_cache_dtype == "fp8":
cloned_value_cache = value_cache.clone() cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
# Using default kv_scale
k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
ops.reshape_and_cache_flash(key, value, key_cache, value_cache, ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype) slot_mapping, kv_cache_dtype, k_scale, v_scale)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache, key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache, value_cache)
# Run the reference implementation. # Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
...@@ -270,8 +290,18 @@ def test_reshape_and_cache_flash( ...@@ -270,8 +290,18 @@ def test_reshape_and_cache_flash(
cloned_key_cache[block_idx, block_offset, :, :] = key[i] cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i] cloned_value_cache[block_idx, block_offset, :, :] = value[i]
assert torch.allclose(key_cache, cloned_key_cache) if kv_cache_dtype == "fp8":
assert torch.allclose(value_cache, cloned_value_cache) assert torch.allclose(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
assert torch.allclose(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
else:
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("direction", COPYING_DIRECTION)
...@@ -300,6 +330,8 @@ def test_swap_blocks( ...@@ -300,6 +330,8 @@ def test_swap_blocks(
) -> None: ) -> None:
if kv_cache_dtype == "fp8" and "cpu" in direction: if kv_cache_dtype == "fp8" and "cpu" in direction:
pytest.skip() pytest.skip()
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed) random.seed(seed)
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
......
...@@ -106,8 +106,8 @@ def cutlass_int8_gemm_helper(m: int, ...@@ -106,8 +106,8 @@ def cutlass_int8_gemm_helper(m: int,
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
@pytest.mark.parametrize("m", [512, 222, 100, 33, 1]) @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024]) @pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False])
...@@ -119,8 +119,8 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, ...@@ -119,8 +119,8 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias) cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
@pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024]) @pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False])
......
...@@ -20,6 +20,7 @@ def ref_paged_attn( ...@@ -20,6 +20,7 @@ def ref_paged_attn(
block_tables: torch.Tensor, block_tables: torch.Tensor,
scale: float, scale: float,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_seqs = len(query_lens) num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy() block_tables = block_tables.cpu().numpy()
...@@ -53,6 +54,8 @@ def ref_paged_attn( ...@@ -53,6 +54,8 @@ def ref_paged_attn(
(query_len + sliding_window) + (query_len + sliding_window) +
1).bool().logical_not() 1).bool().logical_not()
mask |= sliding_window_mask mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf")) attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype) attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v) out = torch.einsum("hqk,khd->qhd", attn, v)
...@@ -68,13 +71,15 @@ def ref_paged_attn( ...@@ -68,13 +71,15 @@ def ref_paged_attn(
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
kv_lens: List[int], kv_lens: List[int],
num_heads: Tuple[int, int], num_heads: Tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
soft_cap: Optional[float],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
...@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv( ...@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
causal=True, causal=True,
block_table=block_tables, block_table=block_tables,
cache_seqlens=kv_lens_tensor, cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1) ).squeeze(1)
ref_output = ref_paged_attn( ref_output = ref_paged_attn(
...@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv( ...@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
kv_lens=kv_lens, kv_lens=kv_lens,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap,
) )
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv( ...@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None]) @pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@torch.inference_mode()
def test_varlen_with_paged_kv( def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]], seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int], num_heads: Tuple[int, int],
...@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv( ...@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
sliding_window: Optional[int], sliding_window: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
soft_cap: Optional[float],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
...@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv( ...@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
head_size, head_size,
dtype=dtype) dtype=dtype)
value_cache = torch.randn_like(key_cache) value_cache = torch.randn_like(key_cache)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache /= head_size**0.5
value_cache /= head_size**0.5
cu_query_lens = torch.tensor([0] + query_lens, cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0, dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32) dtype=torch.int32)
...@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv( ...@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
causal=True, causal=True,
window_size=window_size, window_size=window_size,
block_table=block_tables, block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
) )
ref_output = ref_paged_attn( ref_output = ref_paged_attn(
...@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv( ...@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
sliding_window=sliding_window, sliding_window=sliding_window,
soft_cap=soft_cap,
) )
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
import pytest import pytest
import torch import torch
# ruff: noqa: F401
import vllm._C
from tests.kernels.quant_utils import ref_dynamic_per_token_quant from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from vllm._custom_ops import scaled_int8_quant from vllm._custom_ops import scaled_int8_quant
......
...@@ -9,11 +9,14 @@ from tests.quantization.utils import is_quant_method_supported ...@@ -9,11 +9,14 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_empty_g_idx, marlin_permute_scales) marlin_permute_scales, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
pack_fp8_to_int32) pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
...@@ -21,12 +24,14 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( ...@@ -21,12 +24,14 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_weights) marlin_weights)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize) marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp, awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
sort_weights)
ACT_ORDER_OPTS = [False, True] ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True] K_FULL_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]
MARLIN_K_CHUNKS = [128] MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 128, 256] MARLIN_N_CHUNKS = [64, 128, 256]
...@@ -59,12 +64,13 @@ def rand_data(shape, dtype=torch.float16): ...@@ -59,12 +64,13 @@ def rand_data(shape, dtype=torch.float16):
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors): act_order, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor size_m = m_factor
...@@ -89,11 +95,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, ...@@ -89,11 +95,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
# Quantize (and apply act_order if provided) # Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits, w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
group_size, act_order) b_weight, quant_type, group_size, act_order)
# Pack to GPTQ format # Pack to GPTQ format
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" so that group ids are # For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing # increasing
...@@ -102,8 +108,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, ...@@ -102,8 +108,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Pack to Marlin format # Pack to Marlin format
weight_perm = get_weight_perm(num_bits) weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
# Run Marlin repack GPU kernel # Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack( marlin_q_w_2 = ops.gptq_marlin_repack(
...@@ -111,7 +118,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, ...@@ -111,7 +118,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
sort_indices, sort_indices,
size_k, size_k,
size_n, size_n,
num_bits, quant_type.size_bits,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -122,10 +129,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, ...@@ -122,10 +129,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors): mnk_factors):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
...@@ -144,22 +152,25 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, ...@@ -144,22 +152,25 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
# Quantize # Quantize
w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits, w_ref, q_w, s, zp = quantize_weights(b_weight,
group_size) quant_type,
group_size,
zero_points=True)
# Pack to AWQ format # Pack to AWQ format
q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format # Pack to Marlin format
weight_perm = get_weight_perm(num_bits) weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
# Run Marlin repack GPU kernel # Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack( marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq, q_w_awq,
size_k, size_k,
size_n, size_n,
num_bits, quant_type.size_bits,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -170,19 +181,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, ...@@ -170,19 +181,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm( def test_gptq_marlin_gemm(
k_chunk, k_chunk,
n_chunk, n_chunk,
num_bits, quant_type,
group_size, group_size,
mnk_factors, mnk_factors,
act_order, act_order,
is_k_full, is_k_full,
use_fp32_reduce,
): ):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
...@@ -203,7 +217,7 @@ def test_gptq_marlin_gemm( ...@@ -203,7 +217,7 @@ def test_gptq_marlin_gemm(
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, num_bits, group_size, act_order) b_weight, quant_type, group_size, act_order)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device) marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
...@@ -218,12 +232,13 @@ def test_gptq_marlin_gemm( ...@@ -218,12 +232,13 @@ def test_gptq_marlin_gemm(
g_idx, g_idx,
sort_indices, sort_indices,
workspace.scratch, workspace.scratch,
num_bits, quant_type,
a_input.shape[0], a_input.shape[0],
b_weight.shape[1], b_weight.shape[1],
a_input.shape[1], a_input.shape[1],
is_k_full, is_k_full=is_k_full,
has_zp=False, has_zp=False,
use_fp32_reduce=use_fp32_reduce,
) )
output_ref = torch.matmul(a_input, w_ref) output_ref = torch.matmul(a_input, w_ref)
...@@ -239,10 +254,10 @@ def test_gptq_marlin_gemm( ...@@ -239,10 +254,10 @@ def test_gptq_marlin_gemm(
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors): mnk_factors):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
...@@ -257,7 +272,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, ...@@ -257,7 +272,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta, (w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size) marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)
workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL) GPTQ_MARLIN_24_MAX_PARALLEL)
...@@ -270,7 +285,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, ...@@ -270,7 +285,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
marlin_24_meta, marlin_24_meta,
marlin_24_s, marlin_24_s,
workspace_24.scratch, workspace_24.scratch,
num_bits, quant_type,
a_input.shape[0], a_input.shape[0],
b_weight.shape[1], b_weight.shape[1],
a_input.shape[1], a_input.shape[1],
...@@ -362,15 +377,18 @@ def test_fp8_marlin_gemm( ...@@ -362,15 +377,18 @@ def test_fp8_marlin_gemm(
reason="Marlin is not supported on this GPU type.") reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm( def test_awq_marlin_gemm(
k_chunk, k_chunk,
n_chunk, n_chunk,
num_bits, quant_type,
group_size, group_size,
mnk_factors, mnk_factors,
use_fp32_reduce,
): ):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
...@@ -385,7 +403,7 @@ def test_awq_marlin_gemm( ...@@ -385,7 +403,7 @@ def test_awq_marlin_gemm(
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, num_bits, group_size) b_weight, quant_type, group_size)
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
...@@ -403,12 +421,13 @@ def test_awq_marlin_gemm( ...@@ -403,12 +421,13 @@ def test_awq_marlin_gemm(
g_idx, g_idx,
sort_indices, sort_indices,
workspace.scratch, workspace.scratch,
num_bits, quant_type,
a_input.shape[0], a_input.shape[0],
b_weight.shape[1], b_weight.shape[1],
a_input.shape[1], a_input.shape[1],
is_k_full, is_k_full=is_k_full,
has_zp, has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
) )
output_ref = torch.matmul(a_input, w_ref) output_ref = torch.matmul(a_input, w_ref)
...@@ -418,3 +437,64 @@ def test_awq_marlin_gemm( ...@@ -418,3 +437,64 @@ def test_awq_marlin_gemm(
print("max_diff = {}".format(max_diff)) print("max_diff = {}".format(max_diff))
assert max_diff < 0.04 assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Compare ``ops.marlin_qqq_gemm`` against a dequantized matmul reference.

    Activations are quantized to int8 with one scale per row (the absolute
    maximum along the last dimension divided by the int8 maximum), weights
    are quantized with ``marlin_qqq_quantize``, and the kernel output must
    stay within a max-diff of 0.04 of ``(q_a * s_a) @ w_ref``.
    """
    m_factor, n_factor, k_factor = mnk_factors

    # Problem sizes: M is used as-is, K and N are multiples of their chunks.
    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    # Keep this creation order: both tensors draw from the same RNG stream.
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Quantize activations: per-row scale, then round and saturate to int8.
    int8_limits = torch.iinfo(torch.int8)
    row_abs_max = a_input.abs().max(dim=-1, keepdim=True).values
    s_a = (row_abs_max / int8_limits.max).float()
    scaled_rounded = torch.round(a_input / s_a)
    q_a = torch.clamp(scaled_rounded, int8_limits.min,
                      int8_limits.max).to(torch.int8)

    # Quantize weights
    (w_ref, marlin_qqq_q_w, marlin_qqq_s_group,
     marlin_qqq_s_channel) = marlin_qqq_quantize(b_weight, num_bits,
                                                 group_size)

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    # GEMM dimensions are read back from the tensors actually passed in.
    gemm_m, gemm_k = a_input.shape[0], a_input.shape[1]
    gemm_n = b_weight.shape[1]

    output = ops.marlin_qqq_gemm(
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace.scratch,
        gemm_m,
        gemm_n,
        gemm_k,
    )

    # Reference: dequantize the activations in half precision and multiply
    # by the reference weights returned by the quantizer.
    dequant_a = q_a.half() * s_a.half()
    output_ref = torch.matmul(dequant_a, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
...@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol ...@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE = [True, False] IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256] HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 17] # Arbitrary values for testing NUM_HEADS = [7, 17] # Arbitrary values for testing
BATCH_SIZES = [1, 5] # Arbitrary values for testing BATCH_SIZES = [1, 5] # Arbitrary values for testing
......
import gc import gc
from unittest.mock import patch
import pytest import pytest
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.model_executor.layers.ops.sample import ( from vllm.model_executor.layers.ops.sample import (_sample_triton,
MAX_TRITON_N_COLS, _uniform_to_exponential, get_num_triton_sampler_splits, _uniform_to_exponential,
sample) sample)
from vllm.model_executor.sampling_metadata import SamplingTensors from vllm.model_executor.sampling_metadata import SamplingTensors
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.triton_utils.libentry import LibEntry
from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
get_num_triton_sampler_splits)
SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100
...@@ -75,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of, ...@@ -75,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of,
seeds = torch.randint(1, seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, bs), torch.iinfo(torch.long).max, (n_splits, bs),
device="cuda").mul_(random_sampling_mask) device="cuda").mul_(random_sampling_mask)
sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( #The current _sample_triton does not utilize the
probs=probs, # libentry decoration. The purpose of adding this patch is to test
logprobs=logprobs, # the correctness of libentry.
sample_indices=sample_indices, with patch("vllm.model_executor.layers.ops.sample._sample_triton",
seeds=seeds, LibEntry(_sample_triton)):
max_best_of=max_best_of, sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
modify_greedy_probs=modify_greedy_probs, probs=probs,
save_logprobs=save_logprobs, logprobs=logprobs,
_save_modified_probs=True) sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
_save_modified_probs=True)
assert sampled_tokens.shape == (bs, max_best_of) assert sampled_tokens.shape == (bs, max_best_of)
for i in range(bs): for i in range(bs):
assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
...@@ -129,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of, ...@@ -129,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
[SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of, def test_sample_prompt_logprobs(random_sampling, max_best_of,
modify_greedy_probs, seed, vocab_size): modify_greedy_probs, seed, vocab_size):
set_random_seed(seed) set_random_seed(seed)
prompt_sizes = [16, 32, 64, 128] * 2 prompt_sizes = [16, 32, 64, 128] * 2
samples = 8 samples = 8
...@@ -156,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of, ...@@ -156,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of,
seeds = torch.randint(1, seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, samples), torch.iinfo(torch.long).max, (n_splits, samples),
device="cuda").mul_(random_sampling_mask) device="cuda").mul_(random_sampling_mask)
sampled_tokens, sampled_logprobs, _ = sample( #ditto
probs=probs, with patch("vllm.model_executor.layers.ops.sample._sample_triton",
logprobs=logprobs, LibEntry(_sample_triton)):
sample_indices=sample_indices, sampled_tokens, sampled_logprobs, _ = sample(
seeds=seeds, probs=probs,
max_best_of=max_best_of, logprobs=logprobs,
modify_greedy_probs=modify_greedy_probs, sample_indices=sample_indices,
save_logprobs=True) seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=True)
assert sampled_tokens.shape == (samples, max_best_of) assert sampled_tokens.shape == (samples, max_best_of)
assert sampled_logprobs.shape == (samples, max_best_of) assert sampled_logprobs.shape == (samples, max_best_of)
for i, t in enumerate(sample_indices): for i, t in enumerate(sample_indices):
......
...@@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files): ...@@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files):
expected_lora_output = [ expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n", "more important than knowledge.\nAuthor: Albert Einstein\n",
"everyone else is already taken.\nAuthor: Oscar Wilde\n", "everyone else is already taken.\nAuthor: Oscar Wilde\n",
"so little time\nAuthor: Frank Zappa\n", "so little time.\nAuthor: Frank Zappa\n",
] ]
output1 = do_sample(llm, gemma_lora_files, lora_id=1) output1 = do_sample(llm, gemma_lora_files, lora_id=1)
......
...@@ -22,14 +22,17 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, ...@@ -22,14 +22,17 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora, MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora, QKVParallelLinearWithLora,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA) VocabParallelEmbeddingWithLoRA)
# yapf: enable # yapf: enable
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights, convert_mapping) PackedLoRALayerWeights)
from vllm.lora.punica import PunicaWrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
...@@ -47,6 +50,9 @@ TOLERANCES = { ...@@ -47,6 +50,9 @@ TOLERANCES = {
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
# Different triton kernels are launched for the prefill and decode stages,
# so both must be tested: True selects the prefill stage, False the decode stage.
STAGES = [True, False]
def get_random_id_to_index(num_loras: int, def get_random_id_to_index(num_loras: int,
...@@ -182,10 +188,12 @@ def create_random_inputs( ...@@ -182,10 +188,12 @@ def create_random_inputs(
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: @pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -204,7 +212,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -204,7 +212,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
embedding, lora_embedding = create_random_embedding_layer() embedding, lora_embedding = create_random_embedding_layer()
lora_embedding.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
layer=lora_embedding, layer=lora_embedding,
...@@ -217,12 +225,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -217,12 +225,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
input_size=(200, ), input_size=(200, ),
input_range=(1, vocab_size), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
vocab_size, vocab_size,
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info)
lora_result = lora_embedding(torch.cat(inputs)) lora_result = lora_embedding(torch.cat(inputs))
...@@ -255,12 +263,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -255,12 +263,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
input_size=(200, ), input_size=(200, ),
input_range=(1, vocab_size), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
vocab_size, vocab_size,
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(inputs)) lora_result = lora_embedding(torch.cat(inputs))
expected_result = embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs))
...@@ -278,11 +286,13 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: ...@@ -278,11 +286,13 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device, def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size) -> None: vocab_size, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -318,6 +328,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -318,6 +328,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
generate_embeddings_tensor=256, generate_embeddings_tensor=256,
) )
lora_embedding.set_mapping(punica_wrapper)
# All embeddings tensors have the same shape. # All embeddings tensors have the same shape.
embeddings_tensors = [ embeddings_tensors = [
lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
...@@ -334,8 +345,12 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -334,8 +345,12 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
input_size=(200, ), input_size=(200, ),
input_range=(1, vocab_size), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
vocab_size,
lora_config.lora_extra_vocab_size)
original_inputs = deepcopy(inputs) original_inputs = deepcopy(inputs)
# Force some of the inputs to be in the extended embeddings range # Force some of the inputs to be in the extended embeddings range
...@@ -349,11 +364,6 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -349,11 +364,6 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
(embedding_id + 1) * embeddings_tensor_len - 1) (embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = vocab_size + embeddings_tensor_len - 1 original_input_[-2] = vocab_size + embeddings_tensor_len - 1
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
expanded_embedding.weight[vocab_size:vocab_size + expanded_embedding.weight[vocab_size:vocab_size +
(embeddings_tensor_len * (embeddings_tensor_len *
max_loras)] = torch.cat(embeddings_tensors) max_loras)] = torch.cat(embeddings_tensors)
...@@ -390,15 +400,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -390,15 +400,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
input_size=(200, ), input_size=(200, ),
input_range=(1, vocab_size), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
original_inputs = deepcopy(inputs) original_inputs = deepcopy(inputs)
lora_mapping = LoRAMapping(index_mapping,
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, prompt_mapping,
is_prefill=stage)
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
vocab_size, vocab_size,
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(original_inputs)) lora_result = lora_embedding(torch.cat(original_inputs))
expected_result = expanded_embedding(torch.cat(inputs)) expected_result = expanded_embedding(torch.cat(inputs))
...@@ -413,11 +421,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -413,11 +421,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_lm_head_logits_processor(dist_init, num_loras, device, @pytest.mark.parametrize("stage", STAGES)
vocab_size) -> None: def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = PunicaWrapper(8192, 256, device)
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
...@@ -443,7 +453,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -443,7 +453,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, logits_processor, lora_logits_processor = _pretest() linear, logits_processor, lora_logits_processor = _pretest()
lora_logits_processor.set_mapping(punica_wrapper)
# NOTE: all the generated loras share the same embeddings tensor. # NOTE: all the generated loras share the same embeddings tensor.
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
...@@ -461,17 +471,17 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -461,17 +471,17 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
input_ = torch.rand(20, 1024) is_prefill=stage)
mapping_info = convert_mapping( punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
vocab_size, vocab_size,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_logits_processor.set_mapping(*mapping_info, ) input_ = torch.rand(20, 1024)
lora_result = lora_logits_processor._get_logits( lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs), hidden_states=torch.cat(inputs),
...@@ -510,12 +520,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -510,12 +520,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, is_prefill=stage)
vocab_size, punica_wrapper.update_metadata(
lora_config.lora_extra_vocab_size) lora_mapping,
lora_logits_processor.set_mapping(*mapping_info, ) id_to_index,
max_loras,
vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_result = lora_logits_processor._get_logits( lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs), hidden_states=torch.cat(inputs),
...@@ -533,15 +547,118 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, ...@@ -533,15 +547,118 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
atol=atol) atol=atol)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
    """Check ReplicatedLinearWithLoRA against a hand-computed reference.

    For ten seeds, a fresh 4096x4096 ReplicatedLinear is wrapped in
    ReplicatedLinearWithLoRA, random LoRAs are loaded, and the wrapped
    layer's output is compared with ``linear(x) + x @ A @ B * scaling``
    computed per input. Every slot is then reset and the wrapped layer
    must match the plain layer.

    ``stage`` is forwarded as ``LoRAMapping.is_prefill``.
    """
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def build_layers():
        # Base layer with random weights plus its LoRA-enabled wrapper.
        base = ReplicatedLinear(4096,
                                4096,
                                bias=False,
                                params_dtype=torch.float16)
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = ReplicatedLinearWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        return base, wrapped

    def run_wrapped(wrapped, slot_map, batch, token_ids, prompt_ids):
        # Refresh the punica metadata for this batch, then run the wrapper.
        mapping = LoRAMapping(token_ids, prompt_ids, is_prefill=stage)
        punica_wrapper.update_metadata(
            mapping,
            slot_map,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        return wrapped(torch.cat(batch))[0]

    for seed in range(10):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = build_layers()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_result = run_wrapped(lora_linear, id_to_index, inputs,
                                  index_mapping, prompt_mapping)

        # Reference: base output plus the LoRA delta, one input at a time.
        expected_results: List[torch.Tensor] = []
        for single_input, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            reference = linear(single_input)[0]
            reference += (single_input @ lora.lora_a @ lora.lora_b *
                          lora.scaling)
            expected_results.append(reference)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_result = run_wrapped(lora_linear, id_to_index, inputs,
                                  index_mapping, prompt_mapping)
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device) -> None: device, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -575,7 +692,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -575,7 +692,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_parallel_layer() linear, lora_linear = create_random_linear_parallel_layer()
lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
layer=lora_linear, layer=lora_linear,
...@@ -589,16 +706,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -589,16 +706,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
mapping_info = convert_mapping( is_prefill=stage)
punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
512, 512,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_linear.set_mapping(*mapping_info, )
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
...@@ -628,11 +745,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -628,11 +745,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size) 512, lora_config.lora_extra_vocab_size)
lora_linear.set_mapping(*mapping_info, )
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0]
...@@ -649,10 +767,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -649,10 +767,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
@pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device) -> None: device, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -707,7 +827,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -707,7 +827,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_column_parallel_packed_layer() linear, lora_linear = create_column_parallel_packed_layer()
lora_linear.set_mapping(punica_wrapper)
lora_dict, sublora_dict = populate_loras( lora_dict, sublora_dict = populate_loras(
id_to_index, id_to_index,
layer=lora_linear, layer=lora_linear,
...@@ -722,16 +842,17 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -722,16 +842,17 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
mapping_info = convert_mapping( punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
512, 512,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_linear.set_mapping(*mapping_info)
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
...@@ -762,16 +883,18 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -762,16 +883,18 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
input_range=(0, 1), input_range=(0, 1),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping,
prompt_mapping,
is_prefill=stage)
mapping_info = convert_mapping( punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
512, 512,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_linear.set_mapping(*mapping_info) # lora_linear.set_mapping(*mapping_info)
lora_result = lora_linear(torch.cat(inputs))[0] lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0]
...@@ -803,7 +926,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -803,7 +926,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
...@@ -825,6 +948,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -825,6 +948,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
is_neox_style, is_neox_style,
) )
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
lora_rope.set_mapping(punica_wrapper)
lora_rope.create_lora_weights(max_loras, lora_config) lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base, linear_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, { is_neox_style, {
...@@ -840,6 +964,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -840,6 +964,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
input_range=(0, lora_config.lora_extra_vocab_size), input_range=(0, lora_config.lora_extra_vocab_size),
input_type=torch.float16, input_type=torch.float16,
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
long_lora_context = LongContextLoRAContext(list(scaling_factors), long_lora_context = LongContextLoRAContext(list(scaling_factors),
rotary_dim) rotary_dim)
...@@ -854,7 +979,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -854,7 +979,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
for i in range(len(scaling_factors)): for i in range(len(scaling_factors)):
long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
scaling_factors[i], 0) scaling_factors[i], 0)
mapping_info = convert_mapping( punica_wrapper.update_metadata(
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
...@@ -862,7 +987,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -862,7 +987,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
long_lora_context=long_lora_context, long_lora_context=long_lora_context,
) )
lora_rope.set_mapping(*mapping_info) # lora_rope.set_mapping(*mapping_info)
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, query = torch.randn(batch_size,
......
import pytest
import torch
from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice
from .utils import DummyLoRAManager
# Values used for the `m` and `n` parameters of the tests below.
TENSOR_SIZES = [
    128,
    1024,
    2048,
    4096,
    8192,
    11008,
    11008 // 2,
    11008 // 4,
]

# Triples passed as `qkv`; entry 0 sizes the q weight and entry 1 sizes the
# k and v weights in test_apply_lora_packed_3slice.
QKV_TENSOR_SIZES = [
    (8192, 1024, 1024),
    (8192 // 8, 1024 // 8, 1024 // 8),
    (4096, 4096, 4096),
    (4096 // 2, 4096 // 2, 4096 // 2),
]

BATCH_SIZES = [8, 32, 256]
RANKS = [8]
DTYPES = [torch.float16]

# dtype -> (rtol, atol) handed to torch.allclose.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}
@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora(m, n, k, rank, dtype) -> None:
    """Compare _apply_lora with ``x @ A @ B * scaling`` for one random LoRA.

    The same LoRA is written into all 8 slots of the stacked weights, so a
    random slot index per row must give the reference result. A second
    call with every index set to -1 must leave the zeroed output untouched.
    """
    manager = DummyLoRAManager()
    module_name = "module"
    weight = torch.rand([m, n], device="cuda", dtype=dtype)
    manager.init_random_lora(module_name, weight, rank=rank)
    lora = manager.get_module_lora(module_name)

    batch = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = batch @ lora.lora_a @ lora.lora_b * lora.scaling

    def zero_stack(weight_2d):
        # 8 slots x 1 layer, holding the transposed weight.
        return torch.zeros(8,
                           1,
                           weight_2d.shape[1],
                           weight_2d.shape[0],
                           device="cuda",
                           dtype=dtype)

    lora_a_stack = zero_stack(lora.lora_a)
    lora_b_stack = zero_stack(lora.lora_b)
    # Broadcast the single LoRA into every slot; scaling is folded into B.
    lora_a_stack[:, 0] = lora.lora_a.T
    lora_b_stack[:, 0] = (lora.lora_b * lora.scaling).T

    output = torch.zeros(k, m, device="cuda", dtype=dtype)
    random_slots = torch.randint(0,
                                 lora_a_stack.shape[0], (len(batch), ),
                                 device="cuda")
    _apply_lora(batch, lora_a_stack, lora_b_stack, random_slots, output)

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    # Index -1 on every row: the output must stay all-zero.
    output[:] = 0
    no_lora = torch.full((len(batch), ), -1, device="cuda")
    _apply_lora(batch, lora_a_stack, lora_b_stack, no_lora, output)
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()
@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
    """Compare _apply_lora_packed_nslice with two concatenated LoRA outputs.

    The output of width ``m`` is split into two halves, each with its own
    random LoRA. Cases where ``m`` is odd or ``m // 2`` is not in
    TENSOR_SIZES are skipped. A second call with every index set to -1
    must leave the zeroed output untouched.
    """
    if m % 2 != 0:
        pytest.skip("m must be divisible by 2")
    if m // 2 not in TENSOR_SIZES:
        pytest.skip("m//2 must be in TENSOR_SIZES")

    half = m // 2
    manager = DummyLoRAManager()
    module_name = "module"
    weight = torch.rand([half, n], device="cuda", dtype=dtype)
    manager.init_random_lora(module_name + "1", weight, rank=rank)
    lora_1 = manager.get_module_lora(module_name + "1")
    manager.init_random_lora(module_name + "2", weight, rank=rank)
    lora_2 = manager.get_module_lora(module_name + "2")
    sub_loras = (lora_1, lora_2)

    batch = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = torch.cat(
        [
            batch @ sub.lora_a @ sub.lora_b * sub.scaling
            for sub in sub_loras
        ],
        dim=1)

    def zero_stack(weight_2d):
        # 8 slots x 1 layer, holding the transposed weight.
        return torch.zeros(8,
                           1,
                           weight_2d.shape[1],
                           weight_2d.shape[0],
                           device="cuda",
                           dtype=dtype)

    # Both slices are sized from lora_1, as in the original test.
    lora_a_stacks = [zero_stack(lora_1.lora_a) for _ in sub_loras]
    lora_b_stacks = [zero_stack(lora_1.lora_b) for _ in sub_loras]
    for a_stack, b_stack, sub in zip(lora_a_stacks, lora_b_stacks,
                                     sub_loras):
        # Same LoRA in every slot; scaling is folded into B.
        a_stack[:, 0] = sub.lora_a.T
        b_stack[:, 0] = (sub.lora_b * sub.scaling).T

    output = torch.zeros(k, m, device="cuda", dtype=dtype)
    output_slices = (half, half)
    random_slots = torch.randint(0,
                                 lora_a_stacks[0].shape[0], (len(batch), ),
                                 device="cuda")
    _apply_lora_packed_nslice(batch, lora_a_stacks, lora_b_stacks,
                              random_slots, output, output_slices)

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    # Index -1 on every row: the output must stay all-zero.
    output[:] = 0
    no_lora = torch.full((len(batch), ), -1, device="cuda")
    _apply_lora_packed_nslice(batch, lora_a_stacks, lora_b_stacks, no_lora,
                              output, output_slices)
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()
@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
    """Compare _apply_lora_packed_nslice with three concatenated LoRA outputs.

    Separate random LoRAs are created for q (width ``qkv[0]``) and for k
    and v (both built from a weight of width ``qkv[1]``). The packed result
    must equal the q, k and v reference outputs concatenated along dim 1.
    A second call with every index set to -1 must leave the zeroed output
    untouched.
    """
    manager = DummyLoRAManager()
    module_name = "module"
    weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype)
    weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype)
    manager.init_random_lora(module_name + "q", weight_q, rank=rank)
    lora_q = manager.get_module_lora(module_name + "q")
    manager.init_random_lora(module_name + "k", weight_kv, rank=rank)
    lora_k = manager.get_module_lora(module_name + "k")
    manager.init_random_lora(module_name + "v", weight_kv, rank=rank)
    lora_v = manager.get_module_lora(module_name + "v")
    sub_loras = (lora_q, lora_k, lora_v)

    batch = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = torch.cat(
        [
            batch @ sub.lora_a @ sub.lora_b * sub.scaling
            for sub in sub_loras
        ],
        dim=1)

    def zero_stack(weight_2d):
        # 8 slots x 1 layer, holding the transposed weight.
        return torch.zeros(8,
                           1,
                           weight_2d.shape[1],
                           weight_2d.shape[0],
                           device="cuda",
                           dtype=dtype)

    # The k and v stacks are both sized from lora_k, as in the original.
    lora_a_stacks = [
        zero_stack(lora_q.lora_a),
        zero_stack(lora_k.lora_a),
        zero_stack(lora_k.lora_a),
    ]
    lora_b_stacks = [
        zero_stack(lora_q.lora_b),
        zero_stack(lora_k.lora_b),
        zero_stack(lora_k.lora_b),
    ]
    for a_stack, b_stack, sub in zip(lora_a_stacks, lora_b_stacks,
                                     sub_loras):
        # Same LoRA in every slot; scaling is folded into B.
        a_stack[:, 0] = sub.lora_a.T
        b_stack[:, 0] = (sub.lora_b * sub.scaling).T

    output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype)
    output_slices = (qkv[0], qkv[1], qkv[2])
    random_slots = torch.randint(0,
                                 lora_a_stacks[0].shape[0], (len(batch), ),
                                 device="cuda")
    _apply_lora_packed_nslice(batch, lora_a_stacks, lora_b_stacks,
                              random_slots, output, output_slices)

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    # Index -1 on every row: the output must stay all-zero.
    output[:] = 0
    no_lora = torch.full((len(batch), ), -1, device="cuda")
    _apply_lora_packed_nslice(batch, lora_a_stacks, lora_b_stacks, no_lora,
                              output, output_slices)
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()
# Based on code from https://github.com/punica-ai/punica
import pytest
import torch
import vllm.lora.punica as punica
def assert_close(a, b):
    """Assert that tensors ``a`` and ``b`` are close, by dtype of ``a``.

    float16 and bfloat16 use explicit (rtol, atol) pairs. Every other
    dtype, including float32, passes ``rtol=None, atol=None`` so that
    ``torch.testing.assert_close`` applies its own defaults.

    Raises:
        AssertionError: if the tensors differ beyond the tolerance.
    """
    # Fall back to the library defaults rather than raising KeyError for
    # dtypes that have no entry in the table.
    rtol, atol = {
        torch.float16: (5e-3, 5e-3),
        torch.bfloat16: (3e-2, 2e-2),
    }.get(a.dtype, (None, None))
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def _lora_ref_impl(
    y_final: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """Row-by-row float32 reference for the LoRA kernels under test.

    For each row ``i`` of ``x`` the weights at
    ``[indicies[i], layer_idx]`` are selected and transposed.

    * With ``wb_T_all`` given: ``y_final[i] += (x[i] @ wa @ wb) * scale``.
    * With ``wb_T_all`` None: ``y_final[i] += x[i] @ wa``. ``scale`` is
      not applied on this path.

    Args:
        y_final: accumulator, updated in place.
        x: input rows.
        wa_T_all: first-stage weights, indexed ``[lora, layer]``.
        wb_T_all: second-stage weights indexed the same way, or None.
        indicies: per-row LoRA index.
        layer_idx: layer to select from the weight stacks.
        scale: multiplier for the two-stage result.

    Returns:
        ``(y_final, y_stage_1)``, where ``y_stage_1`` holds the float32
        first-stage products ``x[i] @ wa``.
    """
    num_rows = x.shape[0]
    two_stage = wb_T_all is not None
    y_stage_1 = torch.empty(
        (num_rows, wa_T_all.size(-2)),
        dtype=torch.float32,
        device=x.device,
    )
    scale_t = torch.tensor(scale, dtype=torch.float32, device=x.device)

    # Never visit more rows than x has, even if more indices are supplied.
    row_lora_ids = indicies.cpu().tolist()[:num_rows]
    for row, lora_idx in enumerate(row_lora_ids):
        x_row = x[row].unsqueeze(0).to(torch.float32)
        w_first = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(
            torch.float32)
        first_stage = x_row @ w_first
        y_stage_1[row] = first_stage.squeeze(0)

        if two_stage:
            w_second = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(
                torch.float32)
            y_final[row] += (first_stage @ w_second).squeeze(0) * scale_t
        else:
            y_final[row] += y_stage_1[row]

    return y_final, y_stage_1
# Sizes used for the `h1` parameter (last dimension of `x` and `wa_T_all`)
# in the tests below.
H1 = [
    128, 256, 512, 896, 1024, 1152, 1216, 1280,
    1536, 1664, 2048, 2240, 2304, 2368, 2432, 2560,
    2752, 3072, 3328, 3456, 3584, 3712, 4096, 4480,
    4608, 4736, 4864, 5120, 5504, 5632, 5888, 6144,
    6400, 6848, 6912, 7168, 7424, 8192, 8960, 9216,
    9472, 10240, 11008, 11264, 13824, 14336, 14784, 14848,
    15360, 18944, 22016, 22528, 24576, 27392, 27648, 29568,
    29696, 32000, 32256, 32512, 32768, 33024, 36864, 43264,
    49152, 49408, 60544, 60672, 64000, 64256, 102400, 102656,
    128000, 128256,
]
# Sizes used for the `h2` parameter: everything in H1 plus 64, as a separate
# list object.
H2 = [64] + H1

R = [1, 2, 4]
SEED = [0xabcdabcd987]

# One device when exactly one CUDA device is reported, otherwise two.
CUDA_DEVICES = (["cuda:0"] if torch.cuda.device_count() == 1 else
                ["cuda:0", "cuda:1"])
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("r", R)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
    """Check punica.bgmv against the reference for each h1 and rank.

    Only the first-stage weights are used: the reference is called with
    ``wb_T_all=None`` and both paths use a scale of 1.0.
    """
    torch.manual_seed(seed)
    num_loras, num_layers, bs = 4, 1, 32
    dtype = getattr(torch, dtype_str)
    device = torch.device("cuda")

    wa_T_all = torch.randn(num_loras,
                           num_layers,
                           r,
                           h1,
                           dtype=dtype,
                           device=device)
    indices = torch.randint(num_loras, (bs, ),
                            dtype=torch.long,
                            device=device)

    for layer_idx in range(num_layers):
        x = torch.randn(bs, h1, dtype=dtype, device=device)
        y = torch.randn(bs, r, dtype=dtype, device=device)

        # The reference accumulates into its own copy of y.
        y_ref = y.clone()
        _lora_ref_impl(y_ref, x, wa_T_all, None, indices, layer_idx, 1.0)

        y_our = y.clone()
        punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0)

        assert_close(y_ref, y_our)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_lora_correctness(dtype_str, h1, h2, seed, device):
    """Compare ``punica.add_lora`` with ``_lora_ref_impl`` (A and B weights).

    Random inputs of width ``h1`` and outputs of width ``h2`` are generated
    on ``device`` with a fixed rank of 8 and a scale of 0.123.
    """
    torch.manual_seed(seed)
    num_loras, num_layers = 4, 1
    r, bs = 8, 32
    scale = 0.123
    dtype = getattr(torch, dtype_str)

    # Every tensor created below lands on the device under test.
    torch.set_default_device(device)

    wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
    wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype)
    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)

    for layer_idx in range(num_layers):
        x = torch.randn(bs, h1, dtype=dtype)
        y = torch.randn(bs, h2, dtype=dtype)

        y_ref = y.clone()
        _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx,
                       scale)

        y_our = y.clone()
        punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx,
                        scale)

        assert_close(y_ref, y_our)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_lora_correctness_slice(dtype_str, h1, h2, seed, device):
    """Compare ``punica.add_lora_slice`` with ``_lora_ref_impl`` per slice.

    The output of width ``h2`` is split into three equal column slices, each
    with its own pair of A/B weights. Combinations where ``h2`` is not a
    multiple of 3, or where the slice width is not listed in ``H1``, are
    skipped.
    """
    s, remainder = divmod(h2, 3)
    if remainder != 0 or s not in H1:
        pytest.skip("h2 must be divisible by 3 and in supported shapes")

    torch.manual_seed(seed)
    num_loras, num_layers = 4, 1
    r, bs = 8, 32
    scale = 0.123
    dtype = getattr(torch, dtype_str)
    torch.set_default_device(device)

    num_slices = 3
    # All A weights are drawn before any B weights, one tensor per slice.
    wa_T_alls = [
        torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
        for _ in range(num_slices)
    ]
    wb_T_alls = [
        torch.randn(num_loras, num_layers, s, r, dtype=dtype)
        for _ in range(num_slices)
    ]
    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
    # Column bounds of each slice; the last one ends at 3 * s == h2.
    bounds = [(k * s, (k + 1) * s) for k in range(num_slices)]

    for layer_idx in range(num_layers):
        x = torch.randn(bs, h1, dtype=dtype)
        y = torch.randn(bs, h2, dtype=dtype)

        # Reference: update each column slice of the copy in place.
        y_ref = y.clone()
        for (lo, hi), wa, wb in zip(bounds, wa_T_alls, wb_T_alls):
            _lora_ref_impl(y_ref[:, lo:hi], x, wa, wb, indices, layer_idx,
                           scale)

        # Kernel: same slices, addressed by (offset, size) instead of a view.
        y_our = y.clone()
        for (lo, _), wa, wb in zip(bounds, wa_T_alls, wb_T_alls):
            punica.add_lora_slice(y_our, x, wa, wb, indices, layer_idx, scale,
                                  lo, s)

        for lo, hi in bounds:
            assert_close(y_ref[:, lo:hi], y_our[:, lo:hi])
"""
This script is mainly used to test various hidden_sizes. We have collected the
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
import random
from unittest.mock import patch
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
# Hidden sizes before any tensor-parallel split is applied.
HIDDEN_SIZES = [
    128, 256, 512, 896, 1024, 1152, 1216, 1280,
    1536, 1664, 2048, 2240, 2304, 2368, 2432, 2560,
    2752, 3072, 3328, 3456, 3584, 3712, 4096, 4480,
    4608, 4736, 4864, 5120, 5504, 5632, 5888, 6144,
    6400, 6848, 6912, 7168, 7424, 8192, 8960, 9216,
    9472, 10240, 11008, 11264, 13824, 14336, 14784, 14848,
    15360, 18944, 22016, 22528, 24576, 27392, 27648, 29568,
    29696, 32000, 32256, 32512, 32768, 33024, 36864, 43264,
    49152, 49408, 60544, 60672, 64000, 64256, 102400, 102656,
    128000, 128256,
]
# The size of TP
divisibility = [1, 2, 4, 8, 16, 32, 64]
# Every size above floor-divided by every TP size (TP size varies slowest).
all_hidden_size = [
    hidden_size // div for div in divisibility for hidden_size in HIDDEN_SIZES
]
# Collapse duplicates; this rebinds HIDDEN_SIZES to the per-shard sizes that
# the parametrized tests sweep.
HIDDEN_SIZES = list(set(all_hidden_size))

BATCHES = [4]
NUM_LORA = [4]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [32]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]
def assert_close(a, b):
    """Assert that ``a`` and ``b`` match within dtype-specific tolerances.

    The tolerances are picked from ``a.dtype``; a dtype other than float16,
    bfloat16 or float32 raises ``KeyError``.
    """
    tolerance_by_dtype = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }
    rtol, atol = tolerance_by_dtype[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_sgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Check ``sgmv_shrink`` / ``sgmv_expand`` against ``ref_torch_groupgemm``.

    ``op_type`` selects the kernel: "shrink" runs ``sgmv_shrink`` with
    ``scaling``; any other value runs ``sgmv_expand`` with ``add_inputs=True``
    and a reference scale of 1.0.
    """
    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    seq_length = 128
    is_shrink = op_type == "shrink"
    generated = generate_data(batches, hidden_size, num_loras, rank,
                              seq_length, dtype, op_type, device)
    (inputs_tensor, lora_weights, our_out_tensor, ref_out_tensor,
     b_seq_start_loc, lora_indices_tensor, seq_len_tensor,
     indices) = generated

    # Reduce the per-sequence lengths to one Python scalar for the kernels.
    longest = seq_len_tensor.max()
    if isinstance(longest, tuple):
        longest = longest[0]
    max_seq_length = longest.item()

    # Both kernels share the same leading positional arguments.
    kernel_args = (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        batches,
        max_seq_length,
    )
    if is_shrink:
        sgmv_shrink(*kernel_args, scaling)
    else:
        sgmv_expand(*kernel_args, add_inputs=True)

    ref_scaling = scaling if is_shrink else 1.0
    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        ref_scaling,
        op_type,
    )
    if is_shrink:
        # The shrink reference is cast to float32 before the comparison.
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Check ``bgmv_shrink`` / ``bgmv_expand`` against ``ref_torch_groupgemm``.

    Data is generated with a sequence length of 1. While the kernel under
    test runs, the matching Triton kernel is replaced by a ``LibEntry``
    wrapper so that libentry is tested as well.
    """
    from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
    from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    seq_length = 1
    is_shrink = op_type == "shrink"
    generated = generate_data(batches, hidden_size, num_loras, rank,
                              seq_length, dtype, op_type, device)
    (inputs_tensor, lora_weights, our_out_tensor, ref_out_tensor,
     b_seq_start_loc, lora_indices_tensor, seq_len_tensor,
     indices) = generated

    if is_shrink:
        # The current _bgmv_shrink_kernel does not require the libentry
        # decoration. The purpose of adding this patch is to test the
        # correctness of libentry.
        libentry_patch = patch(
            "vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
            LibEntry(_bgmv_shrink_kernel),
        )
        with libentry_patch:
            bgmv_shrink(inputs_tensor, lora_weights, our_out_tensor, indices,
                        scaling)
    else:
        # Same reasoning as above, for the expand kernel.
        libentry_patch = patch(
            "vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
            LibEntry(_bgmv_expand_kernel),
        )
        with libentry_patch:
            bgmv_expand(inputs_tensor,
                        lora_weights,
                        our_out_tensor,
                        indices,
                        add_inputs=True)

    ref_scaling = scaling if is_shrink else 1.0
    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        ref_scaling,
        op_type,
    )
    if is_shrink:
        # The shrink reference is cast to float32 before the comparison.
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Check the expand-slice kernels against ``ref_torch_groupgemm``.

    The output is filled in ``nslices`` consecutive column slices of width
    ``hidden_size``. ``op_type`` "sgmv" uses ``sgmv_expand_slice`` with a
    sequence length of 128; any other value uses ``bgmv_expand_slice`` with a
    sequence length of 1.
    """
    from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    use_sgmv = op_type == "sgmv"
    seq_length = 128 if use_sgmv else 1
    generated = generate_data_for_expand_nslices(batches, hidden_size,
                                                 num_loras, rank, seq_length,
                                                 dtype, nslices, device)
    (inputs_tensor, lora_weights_lst, our_outputs, ref_outputs,
     b_seq_start_loc, lora_indices_tensor, seq_len_tensor,
     indices) = generated

    # Reduce the per-sequence lengths to one Python scalar for the kernels.
    longest = seq_len_tensor.max()
    if isinstance(longest, tuple):
        longest = longest[0]
    max_seq_length = longest.item()

    for index in range(nslices):
        # Slice `index` covers columns [slice_offset, slice_end).
        slice_offset = index * hidden_size
        slice_end = slice_offset + hidden_size
        lora_weights = lora_weights_lst[index]

        if use_sgmv:
            sgmv_expand_slice(
                inputs_tensor,
                lora_weights,
                our_outputs,
                b_seq_start_loc,
                seq_len_tensor,
                lora_indices_tensor,
                batches,
                max_seq_length,
                slice_offset,
                hidden_size,
                add_inputs=True,
            )
        else:
            # The current _bgmv_expand_slice_kernel does not require the
            # libentry decoration. The purpose of adding this patch is to test
            # the correctness of libentry.
            libentry_patch = patch(
                "vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
                LibEntry(_bgmv_expand_slice_kernel),
            )
            with libentry_patch:
                bgmv_expand_slice(
                    inputs_tensor,
                    lora_weights,
                    our_outputs,
                    indices,
                    slice_offset,
                    slice_size=hidden_size,
                    add_inputs=True,
                )

        ref_torch_groupgemm(
            ref_outputs[:, slice_offset:slice_end],
            inputs_tensor,
            lora_weights,
            lora_indices_tensor,
            seq_len_tensor,
            batches,
            1.0,
            op_type="expand",
        )

    assert_close(our_outputs, ref_outputs)
"""
This script is mainly used to test whether Triton kernels can run normally
under different conditions, including various batches, numbers of LoRAs, and
maximum ranks.
"""
import random
from unittest.mock import patch
import pytest
import torch
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.triton_utils.libentry import LibEntry
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
# Parameter grids consumed by the @pytest.mark.parametrize decorators below.
# Hidden sizes to sweep.
HIDDEN_SIZES = [3424, 4096, 4097]
# Batch counts to sweep.
BATCHES = [1, 4, 16, 32]
# Numbers of LoRAs to sweep.
NUM_LORA = [1, 4, 8, 16, 32, 64, 128]
# Element dtypes to sweep.
DTYPES = [torch.float16, torch.bfloat16]
# Maximum ranks to sweep.
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128]
# Scaling factors; the tests apply them on the "shrink" path only.
SCALES = [0.5]
# Seeds passed to random / torch at the start of every test.
SEED = [0]
# Devices the tests run on (evaluates to ["cuda:0"]).
CUDA_DEVICES = [f"cuda:{0}"]
def assert_close(a, b):
    """Compare two tensors with tolerances chosen from the dtype of ``a``.

    float16 and bfloat16 use rtol = atol = 6e-2, float32 uses 1e-2. Any other
    dtype raises ``KeyError`` because no tolerance is registered for it.
    """
    tolerance_table = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }
    rel_tol, abs_tol = tolerance_table[a.dtype]
    torch.testing.assert_close(a, b, rtol=rel_tol, atol=abs_tol)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_sgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Run one sgmv kernel and compare it with ``ref_torch_groupgemm``.

    For ``op_type == "shrink"`` the kernel is ``sgmv_shrink`` and ``scaling``
    is applied on both sides. Otherwise the kernel is ``sgmv_expand`` with
    ``add_inputs=True`` and the reference uses a scale of 1.0.
    """
    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    seq_length = 128
    shrinking = op_type == "shrink"
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data(batches, hidden_size, num_loras, rank, seq_length,
                      dtype, op_type, device)

    # The kernels take the longest sequence length as a Python scalar.
    max_len = seq_len_tensor.max()
    if isinstance(max_len, tuple):
        max_len = max_len[0]
    max_seq_length = max_len.item()

    # Leading positional arguments shared by sgmv_shrink and sgmv_expand.
    shared_args = (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        batches,
        max_seq_length,
    )
    if shrinking:
        sgmv_shrink(*shared_args, scaling)
    else:
        sgmv_expand(*shared_args, add_inputs=True)

    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        scaling if shrinking else 1.0,
        op_type,
    )
    if shrinking:
        # The shrink reference is cast to float32 before the comparison.
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Run one bgmv kernel and compare it with ``ref_torch_groupgemm``.

    Data is generated with a sequence length of 1. For the duration of the
    kernel call, the underlying Triton kernel is swapped for a ``LibEntry``
    wrapper so that libentry itself is tested.
    """
    from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
    from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    seq_length = 1
    shrinking = op_type == "shrink"
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data(batches, hidden_size, num_loras, rank, seq_length,
                      dtype, op_type, device)

    if shrinking:
        # The current _bgmv_shrink_kernel does not require the libentry
        # decoration. The purpose of adding this patch is to test the
        # correctness of libentry.
        wrapped_kernel = patch(
            "vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
            LibEntry(_bgmv_shrink_kernel),
        )
        with wrapped_kernel:
            bgmv_shrink(inputs_tensor, lora_weights, our_out_tensor, indices,
                        scaling)
    else:
        # Same reasoning as above, for the expand kernel.
        wrapped_kernel = patch(
            "vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
            LibEntry(_bgmv_expand_kernel),
        )
        with wrapped_kernel:
            bgmv_expand(inputs_tensor,
                        lora_weights,
                        our_out_tensor,
                        indices,
                        add_inputs=True)

    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        scaling if shrinking else 1.0,
        op_type,
    )
    if shrinking:
        # The shrink reference is cast to float32 before the comparison.
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Fill the output slice by slice and compare with the torch reference.

    Each of the ``nslices`` slices is ``hidden_size`` columns wide and has
    its own LoRA weights. "sgmv" runs ``sgmv_expand_slice`` on data with a
    sequence length of 128; any other ``op_type`` runs ``bgmv_expand_slice``
    on data with a sequence length of 1.
    """
    from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

    random.seed(seed)
    torch.set_default_device(device)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    is_sgmv = op_type == "sgmv"
    seq_length = 128 if is_sgmv else 1
    (
        inputs_tensor,
        lora_weights_lst,
        our_outputs,
        ref_outputs,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data_for_expand_nslices(batches, hidden_size, num_loras,
                                         rank, seq_length, dtype, nslices,
                                         device)

    # The kernels take the longest sequence length as a Python scalar.
    max_len = seq_len_tensor.max()
    if isinstance(max_len, tuple):
        max_len = max_len[0]
    max_seq_length = max_len.item()

    for index in range(nslices):
        # Column window written by this slice.
        slice_offset = index * hidden_size
        slice_stop = slice_offset + hidden_size
        lora_weights = lora_weights_lst[index]

        if is_sgmv:
            sgmv_expand_slice(
                inputs_tensor,
                lora_weights,
                our_outputs,
                b_seq_start_loc,
                seq_len_tensor,
                lora_indices_tensor,
                batches,
                max_seq_length,
                slice_offset,
                hidden_size,
                add_inputs=True,
            )
        else:
            # The current _bgmv_expand_slice_kernel does not require the
            # libentry decoration. The purpose of adding this patch is to test
            # the correctness of libentry.
            wrapped_kernel = patch(
                "vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
                LibEntry(_bgmv_expand_slice_kernel),
            )
            with wrapped_kernel:
                bgmv_expand_slice(
                    inputs_tensor,
                    lora_weights,
                    our_outputs,
                    indices,
                    slice_offset,
                    slice_size=hidden_size,
                    add_inputs=True,
                )

        ref_torch_groupgemm(
            ref_outputs[:, slice_offset:slice_stop],
            inputs_tensor,
            lora_weights,
            lora_indices_tensor,
            seq_len_tensor,
            batches,
            1.0,
            op_type="expand",
        )

    assert_close(our_outputs, ref_outputs)
if __name__ == "__main__":
    # Manual smoke run of the bgmv "expand" path without going through pytest.
    from itertools import product

    # Each tuple is unpacked positionally into test_punica_bgmv, so the
    # iterables must follow its signature exactly:
    # (batches, num_loras, rank, hidden_size, scaling, dtype, op_type, seed,
    #  device).
    # HIDDEN_SIZES was previously missing, which shifted every later argument
    # by one slot and made the call fail with a TypeError.
    lst = list(
        product(
            BATCHES,
            NUM_LORA,
            MAX_RANKS,
            HIDDEN_SIZES,
            [1.0],
            [torch.float16],
            ["expand"],
            SEED,
            CUDA_DEVICES,
        ))
    for ele in lst:
        test_punica_bgmv(*ele)
        print(f"{ele},pass")
...@@ -64,14 +64,16 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size): ...@@ -64,14 +64,16 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
# if torch.cuda.device_count() < tp_size: # if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(model=model.model_path, llm = vllm.LLM(
enable_lora=True, model=model.model_path,
max_num_seqs=16, enable_lora=True,
max_loras=4, max_num_seqs=16,
max_model_len=400, max_loras=4,
tensor_parallel_size=tp_size, max_model_len=400,
quantization=model.quantization, tensor_parallel_size=tp_size,
trust_remote_code=True) gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
trust_remote_code=True)
if model.quantization is None: if model.quantization is None:
expected_no_lora_output = [ expected_no_lora_output = [
...@@ -156,24 +158,28 @@ def test_quant_model_tp_equality(tinyllama_lora_files, model): ...@@ -156,24 +158,28 @@ def test_quant_model_tp_equality(tinyllama_lora_files, model):
# if torch.cuda.device_count() < 2: # if torch.cuda.device_count() < 2:
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}") # pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
llm_tp1 = vllm.LLM(model=model.model_path, llm_tp1 = vllm.LLM(
enable_lora=True, model=model.model_path,
max_num_seqs=16, enable_lora=True,
max_loras=4, max_num_seqs=16,
tensor_parallel_size=1, max_loras=4,
quantization=model.quantization, tensor_parallel_size=1,
trust_remote_code=True) gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
trust_remote_code=True)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
del llm_tp1 del llm_tp1
cleanup() cleanup()
llm_tp2 = vllm.LLM(model=model.model_path, llm_tp2 = vllm.LLM(
enable_lora=True, model=model.model_path,
max_num_seqs=16, enable_lora=True,
max_loras=4, max_num_seqs=16,
tensor_parallel_size=2, max_loras=4,
quantization=model.quantization) tensor_parallel_size=2,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
del llm_tp2 del llm_tp2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment