Commit 99b471c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.1

parents 1925d2e9 468d761b
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
test_chunked_prefill_distributed.py
```
"""
import os
import pytest
import torch
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
) -> None:
# Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256)
assert chunked_prefill_token_size != -1
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
max_num_seqs=max_num_seqs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
......@@ -8,9 +8,9 @@ import pytest
import ray
import torch
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed import (broadcast_tensor_dict,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
......
......@@ -6,9 +6,8 @@ import ray
import torch
import torch.distributed as dist
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators import custom_all_reduce
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
......@@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
init_test_distributed_environment(1, world_size, rank,
distributed_init_port)
custom_ar.init_custom_ar()
custom_all_reduce.init_custom_all_reduce()
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_ar.capture():
with custom_all_reduce.capture():
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
......@@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
distributed_init_port)
sz = 1024
custom_ar.init_custom_ar()
fa = custom_ar.get_handle()
custom_all_reduce.init_custom_all_reduce()
fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp)
assert torch.allclose(out, inp * world_size)
......
import multiprocessing
import os
import pytest
import torch
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.utils import update_environment_variables
def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
env = os.environ.copy()
env = {}
env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
......@@ -26,20 +27,23 @@ def distributed_run(fn, world_size):
for p in processes:
p.join()
for p in processes:
assert p.exitcode == 0
def update_env(fn):
def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
import os
os.environ.update(env)
def wrapped_fn(env):
update_environment_variables(env)
init_distributed_environment()
fn()
return wrapper
return wrapped_fn
@update_env
@worker_fn_wrapper
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
......@@ -54,7 +58,7 @@ def test_pynccl():
distributed_run(worker_fn, 2)
@update_env
@worker_fn_wrapper
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
......
import multiprocessing
import tempfile
def target_fn(env, filepath):
from vllm.utils import update_environment_variables
update_environment_variables(env)
from vllm.utils import nccl_integrity_check
nccl_integrity_check(filepath)
def test_library_file():
# note: don't import vllm.distributed.device_communicators.pynccl
# before running this test, otherwise the library file will be loaded
# and it might interfere with the test
from vllm.utils import find_nccl_library
so_file = find_nccl_library()
with open(so_file, 'rb') as f:
content = f.read()
try:
# corrupt the library file, should raise an exception
with open(so_file, 'wb') as f:
f.write(content[:len(content) // 2])
p = multiprocessing.Process(target=target_fn, args=({}, so_file))
p.start()
p.join()
assert p.exitcode != 0
# move the library file to a tmp path
# test VLLM_NCCL_SO_PATH
fd, path = tempfile.mkstemp()
with open(path, 'wb') as f:
f.write(content)
p = multiprocessing.Process(target=target_fn,
args=({
"VLLM_NCCL_SO_PATH": path
}, path))
p.start()
p.join()
assert p.exitcode == 0
finally:
with open(so_file, 'wb') as f:
f.write(content)
import random
from unittest.mock import MagicMock
import pytest
from transformers import PreTrainedTokenizer
from tests.core.utils import create_seq_group
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])
@pytest.mark.skip_global_cleanup
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
"""Verify multi-step decoding appends token ids correctly.
We append token ids and verify all the token ids were appended correctly.
Note that ignore_eos=True.
"""
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
)
seq_group = create_seq_group(
seq_prompt_len=1024,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(max_tokens=seq_output_len +
num_new_tokens,
ignore_eos=True),
)
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
new_token_ids = list(range(num_new_tokens))
outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]
assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
output_processor.process_outputs(seq_group, outputs)
assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
@pytest.mark.parametrize("max_tokens", [128 + 3])
@pytest.mark.skip_global_cleanup
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, max_tokens: int):
"""Verify tokens after max_tokens are dropped and not appended to the
sequence.
"""
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
)
seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(max_tokens=max_tokens, ),
)
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
new_token_ids = list(range(num_new_tokens))
outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]
assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)
# Expect the processed sequence to not go over max tokens in len.
assert seq.get_len() == seq_prompt_len + max_tokens
# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, seed: int):
"""Verify the eos token id is included in the sequence, but subsequent
tokens are dropped (not appended to sequence).
"""
random.seed(seed)
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()
eos_token_id = 100
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
)
seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(
# Ensure enough space.
max_tokens=seq_output_len + num_new_tokens, ),
)
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
new_token_ids = list(range(num_new_tokens))
assert eos_token_id not in new_token_ids
eos_index = random.randint(0, len(new_token_ids) - 1)
new_token_ids[eos_index] = eos_token_id
outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]
assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)
# Expect the processed sequence to not go beyond provided eos.
assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)
# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:eos_index + 1]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens
@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, seed: int):
"""When sampling parameters dictate that we should ignore the eos token id,
ensure all token ids are appended even if the eos token id is emitted.
"""
random.seed(seed)
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()
eos_token_id = 100
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
)
seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(
# Ensure enough space.
max_tokens=seq_output_len + num_new_tokens,
ignore_eos=True,
),
)
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
new_token_ids = list(range(num_new_tokens))
assert eos_token_id not in new_token_ids
eos_index = random.randint(0, len(new_token_ids) - 1)
new_token_ids[eos_index] = eos_token_id
outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]
assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)
# Expect the processed sequence to go beyond eos.
assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens
# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
seq_output_len]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens
def mock_tokenizer(eos_token_id=1000):
tokenizer = MagicMock(spec=PreTrainedTokenizer)
tokenizer.eos_token_id = eos_token_id
return tokenizer
import pytest
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_computed_prefix_blocks(model: str):
# This test checks if the engine generates completions both with and
# without optional detokenization, that detokenization includes text
# and no-detokenization doesn't, and that both completions have the same
# token_ids.
prompt = (
"You are a helpful assistant. How do I build a car from cardboard and "
"paper clips? Is there an easy to follow video tutorial available "
"online for free?")
llm = LLM(model=model)
sampling_params = SamplingParams(max_tokens=10,
temperature=0.0,
detokenize=False)
outputs_no_detokenization = llm.generate(prompt,
sampling_params)[0].outputs[0]
sampling_params.detokenize = True
outputs_with_detokenization = llm.generate(prompt,
sampling_params)[0].outputs[0]
assert outputs_no_detokenization.text == ''
assert outputs_with_detokenization.text != ''
assert outputs_no_detokenization.token_ids == \
outputs_with_detokenization.token_ids
import pytest
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_skip_tokenizer_initialization(model: str):
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm = LLM(model=model, skip_tokenizer_init=True)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
with pytest.raises(ValueError) as err:
llm.generate("abc", sampling_params)
assert "prompts must be None if" in str(err.value)
outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
sampling_params=sampling_params)
assert len(outputs) > 0
completions = outputs[0].outputs
assert len(completions) > 0
assert completions[0].text == ""
assert completions[0].token_ids
......@@ -3,7 +3,7 @@
2. One of the provided stop tokens
3. The EOS token
Run `pytest tests/samplers/test_stop_reason.py`.
Run `pytest tests/engine/test_stop_reason.py`.
"""
import pytest
......
from typing import Any, List, Optional
import pytest
from vllm import CompletionOutput, LLMEngine, SamplingParams
MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200
@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
return vllm_runner(MODEL)
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")
@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")
@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
# token id 13013 => " organization"
_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)
_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)
def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)
output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason
assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
import pytest
import torch
from transformers import AutoTokenizer
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
RegexLogitsProcessor)
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
TEST_SCHEMA = {
"type": "object",
......@@ -73,3 +76,36 @@ def test_guided_logits_processors():
json_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=TEST_REGEX)
regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=TEST_SCHEMA)
json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
import pytest
from vllm import LLM, SamplingParams
def test_multiple_sampling_params():
llm = LLM(model="facebook/opt-125m",
max_num_batched_tokens=4096,
tensor_parallel_size=1)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = [
SamplingParams(temperature=0.01, top_p=0.95),
SamplingParams(temperature=0.3, top_p=0.95),
SamplingParams(temperature=0.7, top_p=0.95),
SamplingParams(temperature=0.99, top_p=0.95),
]
# Multiple SamplingParams should be matched with each prompt
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(prompts) == len(outputs)
# Exception raised, if the size of params does not match the size of prompts
with pytest.raises(ValueError):
outputs = llm.generate(prompts, sampling_params=sampling_params[:3])
# Single SamplingParams should be applied to every prompt
single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
outputs = llm.generate(prompts, sampling_params=single_sampling_params)
assert len(prompts) == len(outputs)
# sampling_params is None, default params should be applied
outputs = llm.generate(prompts, sampling_params=None)
assert len(prompts) == len(outputs)
\ No newline at end of file
......@@ -141,7 +141,7 @@ def server(zephyr_lora_files):
"--max-cpu-loras",
"2",
"--max-num-seqs",
"128"
"128",
])
ray.get(server_runner.ready.remote())
yield server_runner
......@@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text
async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example JSON for an employee profile "
......@@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
n=3,
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
......@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
......@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
......@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
......@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
assert json1["age"] != json2["age"]
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
n=3,
temperature=1.0,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
......@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
......@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None
......@@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None
assert ip1 != ip2
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 2
......@@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
assert completion.choices[i].text in TEST_CHOICE
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
......@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE
......@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE
assert choice1 != choice2
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42",
extra_body=dict(guided_json=42))
extra_body=dict(guided_json=42,
guided_decoding_backend=guided_decoding_backend))
messages = [{
"role": "system",
......@@ -692,20 +723,51 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
"The best language for type-safe systems programming is "
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": ('what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}')
}],
response_format={"type": "json_object"})
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5,
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
# -9999.0 is the minimum logprob returned by OpenAI
assert all(
isinstance(logprob, float) and logprob >= -9999.0
for token_dict in top_logprobs
for token, logprob in token_dict.items())
content = resp.choices[0].message.content
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": ('what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}')
}],
response_format={"type": "json_object"})
content = resp.choices[0].message.content
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
......@@ -742,5 +804,36 @@ number: "1" | "2"
assert content.strip() == ground_truth
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
model_name: str):
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# test using text and token IDs
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
echo=True,
logprobs=1)
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
list) else prompt
assert (completion.choices[0].text is not None
and re.search(r"^" + prompt_text, completion.choices[0].text))
logprobs = completion.choices[0].logprobs
assert logprobs is not None
assert len(logprobs.text_offset) > 5
assert (len(logprobs.token_logprobs) > 5
and logprobs.token_logprobs[0] is None)
assert (len(logprobs.top_logprobs) > 5
and logprobs.top_logprobs[0] is None)
assert len(logprobs.tokens) > 5
if __name__ == "__main__":
pytest.main([__file__])
import multiprocessing
import sys
import time
import torch
from openai import OpenAI, OpenAIError
from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits
def server_function(port):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + \
("--model facebook/opt-125m --dtype"
f" float32 --api-key token-abc123 --port {port}").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
def test_oot_registration_for_api_server():
port = get_open_port()
server = multiprocessing.Process(target=server_function, args=(port, ))
server.start()
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
while True:
try:
completion = client.chat.completions.create(
model="facebook/opt-125m",
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
else:
raise e
server.kill()
generated_text = completion.choices[0].message.content
# make sure only the first token is generated
rest = generated_text.replace("<s>", "")
assert rest == ""
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0230364128947258,
"1": 0.01979283057153225,
"2": 0.0241350457072258,
"3": 0.0308314748108387,
"4": 0.0430733822286129,
"5": 0.0370396226644516,
"6": 0.0306222103536129,
"7": 0.0357491634786129,
"8": 0.0358189195394516,
"9": 0.0443289652466774,
"10": 0.0433175228536129,
"11": 0.0416782945394516,
"12": 0.0366908498108387,
"13": 0.0432477705180645,
"14": 0.0410505048930645,
"15": 0.0457589291036129,
"16": 0.0418526791036129,
"17": 0.0432477705180645,
"18": 0.0469447560608387,
"19": 0.0514787957072258,
"20": 0.0541294664144516,
"21": 0.0587681382894516,
"22": 0.0625,
"23": 0.0585588738322258,
"24": 0.0600237175822258,
"25": 0.0588030144572258,
"26": 0.0531180277466774,
"27": 0.06396484375,
"28": 0.0603027381002903,
"29": 0.0582101047039032,
"30": 0.0625348836183548,
"31": 0.0585588738322258,
"32": 0.0582798570394516,
"33": 0.0575125589966774,
"34": 0.0590820349752903,
"35": 0.0614188089966774,
"36": 0.0631975457072258,
"37": 0.0615931935608387,
"38": 0.0601283498108387,
"39": 0.0571986623108387,
"40": 0.0670340433716774,
"41": 0.0523507259786129,
"42": 0.0547223798930645,
"43": 0.0631975457072258,
"44": 0.0663713738322258,
"45": 0.0603376142680645,
"46": 0.0652204304933548,
"47": 0.0734514519572258,
"48": 0.0693708211183548,
"49": 0.0725446492433548,
"50": 0.0627790242433548,
"51": 0.0691266804933548,
"52": 0.0688825398683548,
"53": 0.068429134786129,
"54": 0.0605119988322258,
"55": 0.0799386203289032,
"56": 0.0853097140789032,
"57": 0.0661969929933548,
"58": 0.0689871683716774,
"59": 0.0724051371216774,
"60": 0.0541643425822258,
"61": 0.0626743882894516,
"62": 0.0628487765789032,
"63": 0.0607212632894516,
"64": 0.0589076466858387,
"65": 0.0451660193502903,
"66": 0.0453055277466774,
"67": 0.0414341539144516,
"68": 0.0385044664144516,
"69": 0.0414341539144516,
"70": 0.0466308631002903,
"71": 0.0399693101644516,
"72": 0.0437011756002903,
"73": 0.0434221550822258,
"74": 0.0428989976644516,
"75": 0.0401785746216774,
"76": 0.0431082621216774,
"77": 0.0484444759786129,
"78": 0.0417829267680645,
"79": 0.0418178029358387
}
}
}
}
\ No newline at end of file
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0152239128947258,
"1": 0.0188860222697258,
"2": 0.0354178324341774,
"3": 0.0376674123108387,
"4": 0.0418526791036129,
"5": 0.0433175228536129,
"6": 0.0397600457072258,
"7": 0.0424455925822258,
"8": 0.0415387861430645,
"9": 0.0408412404358387,
"10": 0.0395856611430645,
"11": 0.0377371683716774,
"12": 0.0400739423930645,
"13": 0.040771484375,
"14": 0.0393415205180645,
"15": 0.0369001142680645,
"16": 0.03857421875,
"17": 0.0387486070394516,
"18": 0.0403180830180645,
"19": 0.0396205373108387,
"20": 0.0375627800822258,
"21": 0.0407366082072258,
"22": 0.0432477705180645,
"23": 0.0377022884786129,
"24": 0.0399693101644516,
"25": 0.0374581478536129,
"26": 0.0413295216858387,
"27": 0.0442243330180645,
"28": 0.0424804724752903,
"29": 0.0456891767680645,
"30": 0.0409109964966774,
"31": 0.0482352152466774
}
}
}
}
......@@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm._C import cache_ops, ops
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
......@@ -33,7 +33,11 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
<<<<<<< HEAD
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] if not is_hip() else ["auto"]
=======
KV_CACHE_DTYPE = ["auto", "fp8"]
>>>>>>> v0.4.1
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
......@@ -173,6 +177,9 @@ def test_paged_attention(
device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
kv_scale = 1.0
# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
......@@ -189,6 +196,7 @@ def test_paged_attention(
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
......@@ -220,12 +228,13 @@ def test_paged_attention(
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
raise AssertionError(f"Unknown version: {version}")
# Run the reference implementation.
if kv_cache_dtype == "fp8_e5m2":
if kv_cache_dtype == "fp8":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
......@@ -233,14 +242,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)
......@@ -264,7 +273,8 @@ def test_paged_attention(
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
if kv_cache_dtype == "fp8_e5m2":
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
......
......@@ -4,7 +4,8 @@ from typing import Tuple
import pytest
import torch
from vllm._C import cache_ops
from vllm import _custom_ops as ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
......@@ -23,7 +24,11 @@ SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
<<<<<<< HEAD
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] if not is_hip() else ["auto"]
=======
KV_CACHE_DTYPE = ["auto", "fp8"]
>>>>>>> v0.4.1
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
......@@ -79,7 +84,7 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel.
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
ops.copy_blocks(key_caches, value_caches, block_mapping)
# Run the reference implementation.
for src, dsts in block_mapping.items():
......@@ -105,6 +110,7 @@ def test_copy_blocks(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
kv_cache_factory,
......@@ -116,7 +122,10 @@ def test_reshape_and_cache(
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
) -> None:
if not is_hip() and kv_cache_dtype == "fp8":
pytest.skip() # This test is not tuned for e5m2 cuda precision
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
......@@ -132,17 +141,33 @@ def test_reshape_and_cache(
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size, dtype,
None, seed, device)
num_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches.
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(key_cache, cloned_key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(value_cache, cloned_value_cache)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
# Using default kv_scale
kv_scale = 1.0
# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, "auto")
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, kv_scale)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(key_cache, result_key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(value_cache, result_value_cache)
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
......@@ -156,8 +181,18 @@ def test_reshape_and_cache(
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
if kv_cache_dtype == "fp8":
assert torch.allclose(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
assert torch.allclose(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
else:
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
......@@ -169,6 +204,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
kv_cache_factory,
......@@ -181,7 +217,12 @@ def test_swap_blocks(
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
) -> None:
if kv_cache_dtype == "fp8" and "cpu" in direction:
pytest.skip()
if not is_hip() and kv_cache_dtype == "fp8":
pytest.skip() # This test is not tuned for e5m2 cuda precision
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
......@@ -202,24 +243,60 @@ def test_swap_blocks(
# Create the KV caches on the first device.
src_key_caches, src_value_caches = kv_cache_factory(
num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
src_device)
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
seed, src_device)
# Create the KV caches on the second device.
dist_key_caches, dist_value_caches = kv_cache_factory(
num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
dst_device)
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
seed, dst_device)
src_key_caches_clone = src_key_caches[0].clone()
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
for src, dst in block_mapping.items():
assert torch.allclose(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
assert torch.allclose(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu())
@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_conversion(
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
low = -224.0
high = 224.0
shape = (num_blocks, num_heads, head_size, block_size)
cache = torch.empty(shape, dtype=dtype, device=device)
cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache, cache_fp8)
converted_cache = torch.empty_like(cache)
ops.convert_fp8(cache_fp8, converted_cache)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
......@@ -5,7 +5,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
8199] # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0]
CUDA_DEVICES = [
......
......@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
).cuda()
# Load the weights
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment