Commit f48954a4 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.5.0

parents 1dba29d3 8f89d720
......@@ -2,10 +2,8 @@
Run `pytest tests/samplers/test_beam_search.py`.
"""
import gc
import pytest
import torch
# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
......@@ -30,19 +28,13 @@ def test_beam_search_single_input(
beam_width: int,
) -> None:
example_prompts = example_prompts[:1]
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del vllm_model
# NOTE(woosuk): For some reason, the following GC is required to avoid
# GPU OOM errors in the following tests using `vllm_runner`.
gc.collect()
torch.cuda.empty_cache()
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
......
......@@ -7,25 +7,27 @@ import pytest
from vllm import SamplingParams
MODELS = ["facebook/opt-125m"]
# We also test with llama because it has generation_config to specify EOS
# (past regression).
MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-hf"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [1024])
def test_beam_search_single_input(
@pytest.mark.parametrize("max_tokens", [512])
def test_ignore_eos(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
example_prompts = "1 + 1 is"
with vllm_runner(model, dtype=dtype) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
vllm_model = vllm_runner(model, dtype=dtype)
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
ignore_eos_output = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params)
print(len(ignore_eos_output[0].outputs[0].token_ids))
assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10
assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0
for prompt in example_prompts:
ignore_eos_output = vllm_model.model.generate(
prompt, sampling_params=sampling_params)
output_length = len(ignore_eos_output[0].outputs[0].token_ids)
assert output_length == max_tokens
......@@ -14,46 +14,46 @@ def test_logits_processor_force_generate(
model: str,
dtype: str,
) -> None:
vllm_model = vllm_runner(model, dtype=dtype)
tokenizer = vllm_model.model.get_tokenizer()
repeat_times = 2
enforced_answers = " vLLM"
vllm_token_ids = tokenizer.encode(enforced_answers,
add_special_tokens=False)
max_tokens = len(vllm_token_ids) * repeat_times
def pick_vllm(token_ids, logits):
token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
logits[token_id] = torch.finfo(logits.dtype).max
return logits
params_with_logprobs = SamplingParams(
logits_processors=[pick_vllm],
prompt_logprobs=3,
max_tokens=max_tokens,
)
# test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request(
example_prompts[0],
params=params_with_logprobs,
)
# test prompt_logprobs is not None
vllm_model.model._add_request(
example_prompts[1],
params=SamplingParams(
with vllm_runner(model, dtype=dtype) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
repeat_times = 2
enforced_answers = " vLLM"
vllm_token_ids = tokenizer.encode(enforced_answers,
add_special_tokens=False)
max_tokens = len(vllm_token_ids) * repeat_times
def pick_vllm(token_ids, logits):
token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
logits[token_id] = torch.finfo(logits.dtype).max
return logits
params_with_logprobs = SamplingParams(
logits_processors=[pick_vllm],
prompt_logprobs=3,
max_tokens=max_tokens,
),
)
# test grouped requests
vllm_model.model._add_request(
example_prompts[2],
params=SamplingParams(max_tokens=max_tokens),
)
outputs = vllm_model.model._run_engine(use_tqdm=False)
assert outputs[0].outputs[0].text == enforced_answers * repeat_times
)
# test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request(
example_prompts[0],
params=params_with_logprobs,
)
# test prompt_logprobs is not None
vllm_model.model._add_request(
example_prompts[1],
params=SamplingParams(
prompt_logprobs=3,
max_tokens=max_tokens,
),
)
# test grouped requests
vllm_model.model._add_request(
example_prompts[2],
params=SamplingParams(max_tokens=max_tokens),
)
outputs = vllm_model.model._run_engine(use_tqdm=False)
assert outputs[0].outputs[0].text == enforced_answers * repeat_times
......@@ -12,6 +12,7 @@ MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False])
def test_get_prompt_logprobs(
hf_runner,
vllm_runner,
......@@ -19,6 +20,7 @@ def test_get_prompt_logprobs(
dtype,
chunked_prefill_token_size: int,
num_top_logprobs: int,
detokenize: bool,
example_prompts,
):
max_num_seqs = 256
......@@ -30,27 +32,27 @@ def test_get_prompt_logprobs(
max_num_batched_tokens = chunked_prefill_token_size
max_tokens = 5
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
max_logprobs=num_top_logprobs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_top_logprobs,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
with hf_runner(model, dtype=dtype) as hf_model:
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
with vllm_runner(
model,
dtype=dtype,
max_logprobs=num_top_logprobs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
) as vllm_model:
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_top_logprobs,
temperature=0.0,
detokenize=detokenize)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
# Test whether logprobs are included in the results.
for result in vllm_results:
......@@ -65,11 +67,16 @@ def test_get_prompt_logprobs(
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens.append(
top_logprob.decoded_token)
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens)
assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")
if detokenize:
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens)
assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")
else:
assert output_text == ''
assert output_string_from_most_likely_tokens == [None] * max_tokens
# The first prompt logprob is always None
assert result.prompt_logprobs[0] is None
......@@ -98,9 +105,10 @@ def test_get_prompt_logprobs(
hf_logprob[i][-1][token_id].item(),
atol=1e-1,
rtol=1e-1)
assert isinstance(sample_logprob.decoded_token, str), (
"The token should be decoded by the time it is returned "
" to the user.")
if detokenize:
assert isinstance(sample_logprob.decoded_token, str), (
"The token should be decoded by the time it is returned"
" to the user.")
# Test if prompt logprobs are correctly set.
for vllm_result in vllm_results:
......
......@@ -17,16 +17,27 @@ def test_ranks(
num_top_logprobs = 5
num_prompt_logprobs = 5
vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
## Test greedy logprobs ranks
vllm_sampling_params = SamplingParams(temperature=0.0,
top_p=1.0,
max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_prompt_logprobs)
vllm_results = vllm_model.generate_w_logprobs(example_prompts,
vllm_sampling_params)
with vllm_runner(model, dtype=dtype,
max_logprobs=num_top_logprobs) as vllm_model:
## Test greedy logprobs ranks
vllm_sampling_params = SamplingParams(
temperature=0.0,
top_p=1.0,
max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_prompt_logprobs)
vllm_results = vllm_model.generate_w_logprobs(example_prompts,
vllm_sampling_params)
## Test non-greedy logprobs ranks
sampling_params = SamplingParams(temperature=1.0,
top_p=1.0,
max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_prompt_logprobs)
res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
for result in vllm_results:
assert result[2] is not None
assert len(result[2]) == len(result[0])
......@@ -35,13 +46,6 @@ def test_ranks(
assert token in logprobs
assert logprobs[token].rank == 1
## Test non-greedy logprobs ranks
sampling_params = SamplingParams(temperature=1.0,
top_p=1.0,
max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_prompt_logprobs)
res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
for result in res:
assert result[2] is not None
assert len(result[2]) == len(result[0])
......
......@@ -17,9 +17,8 @@ RANDOM_SEEDS = list(range(5))
@pytest.fixture
def vllm_model(vllm_runner):
vllm_model = vllm_runner(MODEL, dtype="half")
yield vllm_model
del vllm_model
with vllm_runner(MODEL, dtype="half") as vllm_model:
yield vllm_model
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
......
......@@ -18,9 +18,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalData
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, MultiModalData
from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid
......@@ -76,7 +77,11 @@ class AsyncLLM:
swap_space=swap_space,
enforce_eager=enforce_eager,
max_seq_len_to_capture=max_seq_len_to_capture,
# For now use ray for the distributed back-end, since
# we rely on the use of engine_use_ray=True to avoid
# reinitializing CUDA in the same process (driver worker)
engine_use_ray=True,
distributed_executor_backend="ray",
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
......
......@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_proposals(execute_model_req=ExecuteModelRequest(
proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_proposals(
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
......
......@@ -307,9 +307,10 @@ def test_draft_proposals_full_speculation_len():
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
......@@ -344,9 +345,10 @@ def test_draft_proposals_no_speculations():
k,
prompt_len=prompt_len)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
......@@ -415,9 +417,10 @@ def test_draft_proposals_mixed_k():
prev_output_token_len=prev_output_token_len,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
......
......@@ -50,9 +50,10 @@ def test_ngram_algo_correctness_for_single_no_match():
block_size,
final_prompt_lens=final_prompt_lens)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
......@@ -117,9 +118,10 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
block_size,
final_prompt_lens=final_prompt_lens)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
......@@ -188,9 +190,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
block_size,
final_prompt_lens=final_prompt_lens)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
......
import gc
import json
import os
import subprocess
......@@ -7,7 +6,6 @@ from unittest.mock import MagicMock, patch
import openai
import pytest
import ray
import torch
from vllm import SamplingParams
# yapf: disable
......@@ -71,72 +69,66 @@ def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
loaded_hf_model = vllm_runner(model_ref,
with vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path,
num_readers=1,
s3_endpoint="object.ord1.coreweave.com",
))
)) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) # noqa: E501
assert deserialized_outputs
assert deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path):
vllm_model = vllm_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
outputs = vllm_model.generate(prompts, sampling_params)
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
serialize_vllm_model(vllm_model.model.llm_engine,
config_for_serializing,
encryption_key_path=key_path)
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
outputs = vllm_model.generate(prompts, sampling_params)
del vllm_model
gc.collect()
torch.cuda.empty_cache()
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
serialize_vllm_model(vllm_model.model.llm_engine,
config_for_serializing,
encryption_key_path=key_path)
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path)
loaded_vllm_model = vllm_runner(
with vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=config_for_deserializing)
model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501
assert outputs == deserialized_outputs
assert outputs == deserialized_outputs
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
tmp_path):
hf_model = hf_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
max_tokens = 50
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(hf_model.model)
del hf_model
gc.collect()
torch.cuda.empty_cache()
loaded_hf_model = vllm_runner(model_ref,
with hf_runner(model_ref) as hf_model:
model_path = tmp_path / (model_ref + ".tensors")
max_tokens = 50
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(hf_model.model)
with vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
num_readers=1,
))
)) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens)
deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens)
assert outputs == deserialized_outputs
assert outputs == deserialized_outputs
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
......@@ -150,16 +142,13 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
test_prompts = create_test_prompts(lora_path)
# Serialize model before deserializing and binding LoRA adapters
vllm_model = vllm_runner(model_ref, )
model_path = tmp_path / (model_ref + ".tensors")
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(vllm_model.model.llm_engine,
TensorizerConfig(tensorizer_uri=model_path))
serialize_vllm_model(vllm_model.model.llm_engine,
TensorizerConfig(tensorizer_uri=model_path))
del vllm_model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(
with vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
......@@ -172,10 +161,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
max_cpu_loras=2,
max_num_seqs=50,
max_model_len=1000,
)
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
) as loaded_vllm_model:
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
assert loaded_vllm_model
assert loaded_vllm_model
def test_load_without_tensorizer_load_format(vllm_runner):
......@@ -188,19 +177,15 @@ def test_load_without_tensorizer_load_format(vllm_runner):
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
## Serialize model
vllm_model = vllm_runner(model_ref, )
model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(vllm_model.model.llm_engine,
TensorizerConfig(tensorizer_uri=model_path))
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
model_loader_extra_config = {
"tensorizer_uri": str(model_path),
}
serialize_vllm_model(vllm_model.model.llm_engine,
TensorizerConfig(tensorizer_uri=model_path))
del vllm_model
gc.collect()
torch.cuda.empty_cache()
model_loader_extra_config = {
"tensorizer_uri": str(model_path),
}
## Start OpenAI API server
openai_args = [
......@@ -224,9 +209,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
temperature=0.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert len(completion.choices) == 1
assert len(completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
......@@ -262,18 +246,15 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path))
vllm_model = vllm_runner(model_ref)
outputs = vllm_model.generate(prompts, sampling_params)
serialize_vllm_model(vllm_model.model.llm_engine, config)
with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params)
serialize_vllm_model(vllm_model.model.llm_engine, config)
assert is_vllm_tensorized(config)
del vllm_model
gc.collect()
torch.cuda.empty_cache()
assert is_vllm_tensorized(config)
loaded_vllm_model = vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=config)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
with vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=config) as loaded_vllm_model:
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501
assert outputs == deserialized_outputs
assert outputs == deserialized_outputs
......@@ -63,8 +63,9 @@ def test_get_sliding_window():
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
def test_rope_scaling():
def test_rope_customization():
TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0
LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
llama_model_config = ModelConfig(
......@@ -76,6 +77,7 @@ def test_rope_scaling():
seed=0,
)
assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
assert llama_model_config.max_model_len == 8192
llama_model_config = ModelConfig(
......@@ -86,9 +88,12 @@ def test_rope_scaling():
dtype="float16",
seed=0,
rope_scaling=TEST_ROPE_SCALING,
rope_theta=TEST_ROPE_THETA,
)
assert getattr(llama_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
assert getattr(llama_model_config.hf_config, "rope_theta",
None) == TEST_ROPE_THETA
assert llama_model_config.max_model_len == 16384
longchat_model_config = ModelConfig(
......
......@@ -53,6 +53,27 @@ def test_gc():
assert allocated < 50 * 1024 * 1024
def test_model_from_modelscope(monkeypatch):
# model: https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary
MODELSCOPE_MODEL_NAME = "qwen/Qwen1.5-0.5B-Chat"
monkeypatch.setenv("VLLM_USE_MODELSCOPE", "True")
try:
llm = LLM(model=MODELSCOPE_MODEL_NAME)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
assert len(outputs) == 4
finally:
monkeypatch.delenv("VLLM_USE_MODELSCOPE", raising=False)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
import multiprocessing as mp
import os
import shutil
from tempfile import TemporaryDirectory
......@@ -18,9 +19,7 @@ prompts = [
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
seed=0,
temperature=0,
max_tokens=256,
ignore_eos=True,
)
......@@ -40,51 +39,89 @@ def test_filter_subtensors():
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
for key, tensor in filtered_state_dict.items():
assert tensor.equal(state_dict[key])
# NOTE: don't use `euqal` here, as the tensor might contain NaNs
assert tensor is state_dict[key]
@pytest.fixture(scope="module")
def llama_2_7b_files():
with TemporaryDirectory() as cache_dir:
input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
cache_dir=cache_dir,
ignore_patterns="*.bin*")
yield input_dir
def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
llm_sharded_writer = LLM(model=input_dir, **kwargs)
# Dump worker states to output directory
llm_sharded_writer.llm_engine.model_executor.save_sharded_state(
path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
if not any(file.endswith(ext) for ext in weights_patterns):
shutil.copy(f"{input_dir}/{file}", output_dir)
def _run_generate(input_dir, queue: mp.Queue, **kwargs):
llm = LLM(model=input_dir, **kwargs)
gen = llm.generate(prompts, sampling_params)
queue.put([g.outputs[0].__dict__ for g in gen])
queue.close()
queue.join_thread()
@pytest.mark.parametrize("enable_lora", [False, True])
def test_sharded_state_loader(enable_lora):
weights_patterns = ("*.bin", "*.pt", "*.safetensors")
@pytest.mark.parametrize("tp_size", [1, 2])
def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
llama_2_7b_files):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir:
input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
cache_dir=cache_dir)
llm = LLM(
model=input_dir,
worker_use_ray=True,
gpu_memory_utilization=0.3,
)
# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
if not any(file.endswith(ext) for ext in weights_patterns):
shutil.copy(f"{input_dir}/{file}", output_dir)
del llm.llm_engine.model_executor
llm_before = LLM(
model=input_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
)
gen_before = llm_before.generate(prompts, sampling_params)
out_before = [gen.outputs[0].__dict__ for gen in gen_before]
del llm_before.llm_engine.model_executor
llm_after = LLM(
model=output_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
load_format="sharded_state",
)
gen_after = llm_after.generate(prompts, sampling_params)
out_after = [gen.outputs[0].__dict__ for gen in gen_after]
del llm_after.llm_engine.model_executor
weights_patterns = ("*.safetensors", )
gpu_memory_utilization = 0.8
input_dir = llama_2_7b_files
ctx = mp.get_context("spawn")
# Run in separate processes for memory & CUDA isolation
with TemporaryDirectory() as output_dir:
p = ctx.Process(target=_run_writer,
args=(input_dir, output_dir, weights_patterns),
kwargs=dict(
tensor_parallel_size=tp_size,
distributed_executor_backend="mp",
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=True,
))
p.start()
p.join()
queue = ctx.Queue()
p = ctx.Process(target=_run_generate,
args=(input_dir, queue),
kwargs=dict(
distributed_executor_backend="mp",
enable_lora=enable_lora,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tp_size,
))
p.start()
p.join()
out_before = queue.get()
p = ctx.Process(target=_run_generate,
args=(output_dir, queue),
kwargs=dict(
distributed_executor_backend="mp",
enable_lora=enable_lora,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tp_size,
load_format="sharded_state",
))
p.start()
p.join()
out_after = queue.get()
assert out_before == out_after
import asyncio
import os
import socket
import sys
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)
import pytest
from vllm.utils import deprecate_kwargs, merge_async_iterators
from vllm.utils import deprecate_kwargs, get_open_port, merge_async_iterators
from .utils import error_on_warning
......@@ -116,3 +118,15 @@ def test_deprecate_kwargs_additional_message():
with pytest.warns(DeprecationWarning, match="abcd"):
dummy(old_arg=1)
def test_get_open_port():
os.environ["VLLM_PORT"] = "5678"
# make sure we can get multiple ports, even if the env var is set
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
s1.bind(("localhost", get_open_port()))
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
s2.bind(("localhost", get_open_port()))
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
s3.bind(("localhost", get_open_port()))
os.environ.pop("VLLM_PORT")
import pytest
from transformers.image_processing_utils import BaseImageProcessor
from vllm.transformers_utils.image_processor import get_image_processor
IMAGE_PROCESSOR_NAMES = [
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-34b-hf",
]
@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES)
def test_image_processor_revision(processor_name: str):
# Assume that "main" branch always exists
image_processor = get_image_processor(processor_name, revision="main")
assert isinstance(image_processor, BaseImageProcessor)
# Assume that "never" branch always does not exist
with pytest.raises(OSError, match='not a valid git identifier'):
get_image_processor(processor_name, revision="never")
......@@ -24,7 +24,8 @@ class ServerRunner:
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
[sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
......
......@@ -13,7 +13,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.version import __dcu_version__
__version__ = "0.4.3"
__version__ = "0.5.0"
__all__ = [
"LLM",
......
from typing import Optional, Tuple, Type
import contextlib
from typing import List, Optional, Tuple, Type
import torch
try:
from vllm._C import cache_ops as vllm_cache_ops
from vllm._C import ops as vllm_ops
except ImportError:
pass
import vllm._C
except ImportError as e:
from vllm.logger import init_logger
logger = init_logger(__name__)
logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError):
import vllm._moe_C
with contextlib.suppress(ImportError):
# ruff: noqa: F401
import vllm._punica_C
def is_custom_op_supported(op_name: str) -> bool:
op, overloads = torch._C._jit_get_operation(op_name)
return op is not None
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.silu_and_mul(out, x)
torch.ops._C.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_and_mul(out, x)
torch.ops._C.gelu_and_mul(out, x)
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_tanh_and_mul(out, x)
torch.ops._C.gelu_tanh_and_mul(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_fast(out, x)
torch.ops._C.gelu_fast(out, x)
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_new(out, x)
torch.ops._C.gelu_new(out, x)
# page attention ops
......@@ -51,7 +65,7 @@ def paged_attention_v1(
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
vllm_ops.paged_attention_v1(
torch.ops._C.paged_attention_v1(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
......@@ -81,7 +95,7 @@ def paged_attention_v2(
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
vllm_ops.paged_attention_v2(
torch.ops._C.paged_attention_v2(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
......@@ -98,8 +112,8 @@ def rotary_embedding(
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
is_neox)
torch.ops._C.rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
......@@ -107,20 +121,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
vllm_ops.rms_norm(out, input, weight, epsilon)
torch.ops._C.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# quantization ops
......@@ -128,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
thy)
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# gptq
......@@ -142,27 +156,27 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
lookup_table: torch.Tensor) -> None:
vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)
return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)
# marlin_24
......@@ -170,14 +184,14 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, num_bits, size_m, size_n,
size_k)
return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, num_bits, size_m,
size_n, size_k)
# cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
......@@ -186,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
return out
......@@ -196,21 +210,22 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return vllm_ops.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias)
return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias)
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
return torch.ops._C.aqlm_dequant(codes, codebooks,
codebook_partition_sizes)
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
......@@ -218,9 +233,9 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
# fp8
......@@ -257,28 +272,40 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# if scale is None:
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# vllm_ops.static_scaled_fp8_quant(output, input, scale)
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale
# int8
def static_scaled_int8_quant(input: torch.Tensor,
scale: float) -> torch.Tensor:
def scaled_int8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize the input tensor to int8 and return the quantized tensor.
Quantize the input tensor to int8 and return the quantized tensor and scale.
Args:
input: The input tensor to be quantized to int8.
scale: Scaling factor for the int8 quantization.
scale: Optional scaling factor for the int8 quantization.
When not provided, we invoke dynamic-per-token quantization.
Returns:
torch.Tensor: Output tensor in int8.
Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales.
"""
q = torch.empty_like(input, dtype=torch.int8)
vllm_ops.static_scaled_int8_quant(q, input, scale)
return q
output = torch.empty_like(input, dtype=torch.int8)
if scale is not None:
# static-per-tensor quantization.
torch.ops._C.static_scaled_int8_quant(output, input, scale)
return output, scale
# dynamic-per-token quantization.
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
return output, input_scales
# moe
......@@ -286,9 +313,16 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None:
vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)
torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies: torch.Tensor,
gating_output: float) -> None:
torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
token_expert_indicies, gating_output)
def reshape_and_cache(
......@@ -300,8 +334,9 @@ def reshape_and_cache(
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, kv_scale)
def reshape_and_cache_flash(
......@@ -312,25 +347,115 @@ def reshape_and_cache_flash(
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
) -> None:
vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
def convert_fp8(output: torch.Tensor,
input: torch.Tensor,
scale: float = 1.0,
kv_dtype: str = "fp8") -> None:
vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype)
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# ruff: noqa: E501
return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
device)
# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
handles: List[str], offsets: List[int], rank: int,
full_nvlink: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
offsets, rank, full_nvlink)
def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
full_nvlink: bool) -> bool:
return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
full_nvlink)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
#TODO: cuda_utils, custom_ar
def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa)
def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()
def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
offsets: List[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa: int, handles: List[str],
offsets: List[List[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
# punica
def dispatch_bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.Tensor,
layer_idx: int,
scale: float,
) -> None:
torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
scale)
def dispatch_bgmv_low_level(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.Tensor,
layer_idx: int,
scale: float,
h_in: int,
h_out: int,
y_offset: int,
) -> None:
torch.ops._punica_C.dispatch_bgmv_low_level(
y,
x,
w_t_all,
indicies,
layer_idx,
scale,
h_in,
h_out,
y_offset,
)
......@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm._C import cache_ops
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
......@@ -47,11 +47,11 @@ class FlashAttentionBackend(AttentionBackend):
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
......@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass
......@@ -285,7 +285,7 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
cache_ops.reshape_and_cache_flash(
ops.reshape_and_cache_flash(
key,
value,
key_cache,
......@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out = flash_attn_varlen_func(
flash_attn_varlen_func(
q=query,
k=key,
v=value,
......@@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl):
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
out=output[:num_prefill_tokens],
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
output[:num_prefill_tokens] = flash_attn_varlen_func(
flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
......@@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl):
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
out=output[:num_prefill_tokens],
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = flash_attn_with_kvcache(
flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
......@@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
).squeeze(1)
out=output[num_prefill_tokens:].unsqueeze(1),
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment