Commit 86a65417 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1 (ex tests&vllm)

parent 45a060d6
...@@ -44,23 +44,10 @@ try: ...@@ -44,23 +44,10 @@ try:
except ImportError: except ImportError:
from backend_request_func import get_tokenizer from backend_request_func import get_tokenizer
# triton_version = triton.__version__
# if triton_version.startswith("2.1"):
# from triton.common.backend import compute_core_version_key
# elif triton_version.startswith("3.0"):
# from triton.compiler.compiler import triton_key
# else:
# print(f"TRITON version {triton_version} is not specifically handled.")
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
def test_prefix(llm=None, sampling_params=None, prompts=None): def test_prefix(llm=None, sampling_params=None, prompts=None):
# if triton_version.startswith("2.1"):
# version_key = compute_core_version_key()
# if triton_version.startswith("3.0"):
# version_key = triton_key()
start_time = time.time() start_time = time.time()
llm.generate(prompts, sampling_params=sampling_params) llm.generate(prompts, sampling_params=sampling_params)
...@@ -287,4 +274,4 @@ def create_argument_parser(): ...@@ -287,4 +274,4 @@ def create_argument_parser():
if __name__ == "__main__": if __name__ == "__main__":
parser = create_argument_parser() parser = create_argument_parser()
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
\ No newline at end of file
{
"triton_version": "3.1",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
{
"triton_version": "3.1",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
...@@ -197,7 +197,7 @@ def bench_run( ...@@ -197,7 +197,7 @@ def bench_run(
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4( CutlassExpertsFp4(
make_dummy_moe_config(), make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
...@@ -242,7 +242,7 @@ def bench_run( ...@@ -242,7 +242,7 @@ def bench_run(
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4( CutlassExpertsFp4(
make_dummy_moe_config(), make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
...@@ -520,4 +520,4 @@ if __name__ == "__main__": ...@@ -520,4 +520,4 @@ if __name__ == "__main__":
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
\ No newline at end of file
...@@ -10,6 +10,8 @@ from transformers import AutoConfig ...@@ -10,6 +10,8 @@ from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute,
_moe_unpermute_and_reduce,
moe_permute, moe_permute,
moe_unpermute, moe_unpermute,
) )
...@@ -39,6 +41,7 @@ def benchmark_permute( ...@@ -39,6 +41,7 @@ def benchmark_permute(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
use_customized_permute: bool = False,
) -> float: ) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
...@@ -61,14 +64,29 @@ def benchmark_permute( ...@@ -61,14 +64,29 @@ def benchmark_permute(
input_gating.copy_(gating_output[i]) input_gating.copy_(gating_output[i])
def run(): def run():
moe_permute( if use_customized_permute:
qhidden_states, (
a1q_scale=None, permuted_hidden_states,
topk_ids=topk_ids, a1q_scale,
n_expert=num_experts, first_token_off,
expert_map=None, inv_perm_idx,
align_block_size=align_block_size, m_indices,
) ) = moe_permute(
qhidden_states,
a1q_scale=None,
topk_ids=topk_ids,
n_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
else:
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(qhidden_states, None, topk_ids, num_experts, None, 16)
# JIT compilation & warmup # JIT compilation & warmup
run() run()
...@@ -113,9 +131,11 @@ def benchmark_unpermute( ...@@ -113,9 +131,11 @@ def benchmark_unpermute(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
use_customized_permute: bool = False,
) -> float: ) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8: if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None) qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
...@@ -130,37 +150,78 @@ def benchmark_unpermute( ...@@ -130,37 +150,78 @@ def benchmark_unpermute(
) )
def prepare(): def prepare():
( if use_customized_permute:
permuted_hidden_states, (
_, permuted_hidden_states,
first_token_off, a1q_scale,
inv_perm_idx, first_token_off,
_, inv_perm_idx,
) = moe_permute( m_indices,
qhidden_states, ) = moe_permute(
a1q_scale=None, qhidden_states,
topk_ids=topk_ids, a1q_scale=None,
n_expert=num_experts, topk_ids=topk_ids,
expert_map=None, n_expert=num_experts,
align_block_size=align_block_size, expert_map=None,
) align_block_size=align_block_size,
# convert to fp16/bf16 as gemm output )
return ( # convert to fp16/bf16 as gemm output
permuted_hidden_states.to(dtype), return (
first_token_off, permuted_hidden_states.to(dtype),
inv_perm_idx, first_token_off,
) inv_perm_idx,
m_indices,
)
else:
(
permuted_qhidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(
qhidden_states, None, topk_ids, num_experts, None, block_m=16
)
# convert to fp16/bf16 as gemm output
return (
permuted_qhidden_states.to(dtype),
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
)
def run(input: tuple): def run(input: tuple):
(permuted_hidden_states, first_token_off, inv_perm_idx) = input if use_customized_permute:
output = torch.empty_like(hidden_states) (
moe_unpermute( permuted_hidden_states,
output, first_token_off,
permuted_hidden_states, inv_perm_idx,
topk_weights, m_indices,
inv_perm_idx, ) = input
first_token_off, output = torch.empty_like(hidden_states)
) moe_unpermute(
output,
permuted_hidden_states,
topk_weights,
inv_perm_idx,
first_token_off,
)
else:
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = input
_moe_unpermute_and_reduce(
output_hidden_states,
permuted_hidden_states,
inv_perm,
topk_weights,
True,
)
# JIT compilation & warmup # JIT compilation & warmup
input = prepare() input = prepare()
...@@ -215,7 +276,8 @@ class BenchmarkWorker: ...@@ -215,7 +276,8 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
) -> tuple[float, float]: use_customized_permute: bool = False,
) -> tuple[dict[str, int], float]:
set_random_seed(self.seed) set_random_seed(self.seed)
permute_time = benchmark_permute( permute_time = benchmark_permute(
...@@ -227,6 +289,7 @@ class BenchmarkWorker: ...@@ -227,6 +289,7 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=100, num_iters=100,
use_customized_permute=use_customized_permute,
) )
unpermute_time = benchmark_unpermute( unpermute_time = benchmark_unpermute(
num_tokens, num_tokens,
...@@ -237,6 +300,7 @@ class BenchmarkWorker: ...@@ -237,6 +300,7 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=100, num_iters=100,
use_customized_permute=use_customized_permute,
) )
return permute_time, unpermute_time return permute_time, unpermute_time
...@@ -283,6 +347,7 @@ def main(args: argparse.Namespace): ...@@ -283,6 +347,7 @@ def main(args: argparse.Namespace):
dtype = torch.float16 if current_platform.is_rocm() else config.dtype dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute
if args.batch_size is None: if args.batch_size is None:
batch_sizes = [ batch_sizes = [
...@@ -334,6 +399,7 @@ def main(args: argparse.Namespace): ...@@ -334,6 +399,7 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_customized_permute,
) )
for batch_size in batch_sizes for batch_size in batch_sizes
], ],
...@@ -353,9 +419,10 @@ if __name__ == "__main__": ...@@ -353,9 +419,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
) )
parser.add_argument("--use-customized-permute", action="store_true")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
\ No newline at end of file
...@@ -15,7 +15,6 @@ from vllm.utils.torch_utils import ( ...@@ -15,7 +15,6 @@ from vllm.utils.torch_utils import (
create_kv_caches_with_random, create_kv_caches_with_random,
set_random_seed, set_random_seed,
) )
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -102,10 +101,6 @@ def main( ...@@ -102,10 +101,6 @@ def main(
device=output.device, device=output.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
if version == "v12":
sliding_window = ((-1, -1))
logits_soft_cap = 0.0
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -119,23 +114,23 @@ def main( ...@@ -119,23 +114,23 @@ def main(
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
ops.paged_attention_v1( ops.paged_attention_v1(
output, output,
query, query,
key_cache, key_cache,
value_cache, value_cache,
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
seq_lens, seq_lens,
block_size, block_size,
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
k_scale, k_scale,
v_scale, v_scale,
) )
elif version == "v2": elif version == "v2":
if not args.custom_paged_attn: if not args.custom_paged_attn:
ops.paged_attention_v2( ops.paged_attention_v2(
output, output,
exp_sums, exp_sums,
...@@ -176,24 +171,6 @@ def main( ...@@ -176,24 +171,6 @@ def main(
k_scale, k_scale,
v_scale, v_scale,
) )
elif version == "v12":
from flash_attn import vllm_flash_attn_with_kvcache
vllm_flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
cache_seqlens=seq_lens,
block_table=block_tables,
softmax_scale=scale,
causal=True,
window_size=sliding_window,
softcap=logits_soft_cap,
alibi_slopes=alibi_slopes,
return_softmax_lse=False,
k_scale=k_scale,
v_scale=v_scale,
kv_cache_dtype=kv_cache_dtype,
).squeeze(1)
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -223,7 +200,7 @@ if __name__ == "__main__": ...@@ -223,7 +200,7 @@ if __name__ == "__main__":
) )
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
parser.add_argument("--version", type=str, choices=["v1", "v2", "v12"], default="v12") parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--seq-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
...@@ -234,7 +211,7 @@ if __name__ == "__main__": ...@@ -234,7 +211,7 @@ if __name__ == "__main__":
choices=[64, 80, 96, 112, 120, 128, 192, 256], choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128, default=128,
) )
parser.add_argument("--block-size", type=int, choices=[16, 32, 64], default=64) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
parser.add_argument( parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
...@@ -248,7 +225,8 @@ if __name__ == "__main__": ...@@ -248,7 +225,8 @@ if __name__ == "__main__":
default="auto", default="auto",
help="Data type for kv cache storage. If 'auto', will use model " help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (hcu) supports fp8 (=fp8_e4m3)") "ROCm (hcu) supports fp8 (=fp8_e4m3)",
)
parser.add_argument( parser.add_argument(
"--custom-paged-attn", action="store_true", help="Use custom paged attention" "--custom-paged-attn", action="store_true", help="Use custom paged attention"
) )
......
...@@ -81,12 +81,6 @@ void indexer_k_quant_and_cache( ...@@ -81,12 +81,6 @@ void indexer_k_quant_and_cache(
int64_t quant_block_size, // quantization block size int64_t quant_block_size, // quantization block size
const std::string& scale_fmt); const std::string& scale_fmt);
void indexer_k_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& scale_fmt);
// Extract function to gather quantized K cache // Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache( void cp_gather_indexer_k_quant_cache(
...@@ -94,20 +88,4 @@ void cp_gather_indexer_k_quant_cache( ...@@ -94,20 +88,4 @@ void cp_gather_indexer_k_quant_cache(
torch::Tensor& dst_k, // [num_tokens, head_dim] torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks] const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1] const torch::Tensor& cu_seq_lens); // [batch_size + 1]
\ No newline at end of file
void read_cache(
torch::Tensor& keys,
torch::Tensor& values,
std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
void write_cache_multi_layers(
torch::Tensor& keys,
torch::Tensor& values,
std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
...@@ -29,11 +29,6 @@ ...@@ -29,11 +29,6 @@
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
#if defined(__gfx942__)
constexpr float kFp8ScaleDivisor = 224.f;
#else
constexpr float kFp8ScaleDivisor = 448.f;
#endif
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes, int64_t block_size_in_bytes,
...@@ -389,133 +384,6 @@ __global__ void reshape_and_cache_flash_kernel( ...@@ -389,133 +384,6 @@ __global__ void reshape_and_cache_flash_kernel(
} }
} }
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void write_cache_multi_layers_kernel(
scalar_t* __restrict__ keys, // [num_layers, num_tokens, num_heads, head_size]
scalar_t* __restrict__ values, // [num_layers, num_tokens, num_heads, head_size]
int64_t* key_cache_ptrs, // [num_blocks, num_heads, head_size/x,
// block_size, x]
int64_t* value_cache_ptrs, // [num_blocks, num_heads, head_size,
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size,
const int x, const int num_tokens) {
const int layer_idx = blockIdx.x;
const int token_idx = blockIdx.y;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}
cache_t* key_cache = reinterpret_cast<cache_t*>(key_cache_ptrs[layer_idx]);
cache_t* value_cache =
reinterpret_cast<cache_t*>(value_cache_ptrs[layer_idx]);
scalar_t* key = keys + layer_idx * num_tokens * key_stride;
scalar_t* value = values + layer_idx * num_tokens * value_stride;
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;
const int64_t tgt_key_idx =
block_idx * num_heads * (head_size / x) * block_size * x +
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
block_offset * x + x_offset;
const int64_t tgt_value_idx =
block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
block_offset;
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, 1.0);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, 1.0);
}
}
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void read_cache_kernel(
scalar_t* __restrict__ keys, // [num_layers, num_tokens, num_heads, head_size]
scalar_t* __restrict__ values, // [num_layers, num_tokens, num_heads, head_size]
int64_t* key_cache_ptrs, // [num_blocks, num_heads, head_size/x,
// block_size, x]
int64_t* value_cache_ptrs, // [num_blocks, num_heads, head_size,
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size,
const int x, const int num_tokens) {
const int layer_idx = blockIdx.x;
const int token_idx = blockIdx.y;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}
cache_t* key_cache = reinterpret_cast<cache_t*>(key_cache_ptrs[layer_idx]);
cache_t* value_cache =
reinterpret_cast<cache_t*>(value_cache_ptrs[layer_idx]);
scalar_t* key = keys + layer_idx * num_tokens * key_stride;
scalar_t* value = values + layer_idx * num_tokens * value_stride;
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;
const int64_t src_key_idx =
block_idx * num_heads * (head_size / x) * block_size * x +
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
block_offset * x + x_offset;
const int64_t src_value_idx =
block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
block_offset;
const int64_t tgt_key_idx = token_idx * key_stride + i;
const int64_t tgt_value_idx = token_idx * value_stride + i;
cache_t tgt_key = key_cache[src_key_idx];
cache_t tgt_value = value_cache[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key[tgt_key_idx] = tgt_key;
value[tgt_value_idx] = tgt_value;
} else {
key[tgt_key_idx] = fp8::scaled_convert<scalar_t, cache_t, kv_dt>(tgt_key, 1.0);
value[tgt_value_idx] = fp8::scaled_convert<scalar_t, cache_t, kv_dt>(tgt_value, 1.0);
}
}
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_kernel( __global__ void concat_and_cache_mla_kernel(
...@@ -638,7 +506,8 @@ __global__ void concat_and_cache_ds_mla_kernel( ...@@ -638,7 +506,8 @@ __global__ void concat_and_cache_ds_mla_kernel(
} }
// Compute the scale for the tile // Compute the scale for the tile
float tile_scale = fmaxf(max_abs / kFp8ScaleDivisor, FLT_MIN); float tile_scale = max_abs / 448.f;
tile_scale = fmaxf(tile_scale, FLT_MIN);
// The first lane of each half-warp writes the scale to kv_cache // The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) { if ((lane_idx == 0) || (lane_idx == 16)) {
...@@ -707,7 +576,11 @@ __global__ void indexer_k_quant_and_cache_kernel( ...@@ -707,7 +576,11 @@ __global__ void indexer_k_quant_and_cache_kernel(
#endif #endif
} }
float scale = fmaxf(amax, 1e-4) / kFp8ScaleDivisor; #if defined(__gfx942__)
float scale = fmaxf(amax, 1e-4) / 224.0f;
#else
float scale = fmaxf(amax, 1e-4) / 448.0f;
#endif
if (use_ue8m0) { if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale))); scale = exp2f(ceilf(log2f(scale)));
...@@ -728,53 +601,6 @@ __global__ void indexer_k_quant_and_cache_kernel( ...@@ -728,53 +601,6 @@ __global__ void indexer_k_quant_and_cache_kernel(
} }
} }
template <typename scalar_t, typename cache_t>
__global__ void indexer_k_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int head_dim, // dimension of each head
const int cache_block_size, // cache block size
const int cache_stride // stride for each token in kv_cache
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
return;
}
float2 k_val = (reinterpret_cast<const float2*>(
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
float val = static_cast<float>(k_val_ptr[i]);
if constexpr (std::is_same<cache_t, at::Half>::value ||
std::is_same<cache_t, __half>::value) {
kv_cache[dst_offset + i] = __float2half(val);
} else if constexpr (std::is_same<cache_t, at::BFloat16>::value ||
std::is_same<cache_t, __nv_bfloat16>::value) {
__hip_bfloat16 bf16_val = __float2bfloat16(val);
kv_cache[dst_offset + i] = *reinterpret_cast<at::BFloat16*>(&bf16_val);
} else if constexpr (std::is_same<cache_t, float>::value) {
kv_cache[dst_offset + i] = val;
} else {
kv_cache[dst_offset + i] = static_cast<cache_t>(val);
}
}
}
template <int BLOCK_Y_SIZE> template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel( __global__ void cp_gather_indexer_k_quant_cache_kernel(
const char* __restrict__ kv_cache, // [num_blocks, block_size, const char* __restrict__ kv_cache, // [num_blocks, block_size,
...@@ -998,149 +824,6 @@ void reshape_and_cache_flash( ...@@ -998,149 +824,6 @@ void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE_FLASH); CALL_RESHAPE_AND_CACHE_FLASH);
} }
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_READ_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::read_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(keys.data_ptr()), \
reinterpret_cast<KV_T*>(values.data_ptr()), \
key_cache_ptrs_tensor.data_ptr<int64_t>(), \
value_cache_ptrs_tensor.data_ptr<int64_t>(), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, value_stride, \
num_heads, head_size, block_size, x, num_tokens);
void read_cache(
torch::Tensor& keys, // [num_layers, seq_len, num_heads, head_size]
torch::Tensor& values, // [num_layers, seq_len, num_heads, head_size]
std::vector<torch::Tensor> const& key_caches, // [num_blocks, num_heads, head_size/x, block_size, x]
std::vector<torch::Tensor> const& value_caches, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
torch::Device cache_device = key_caches[0].device();
TORCH_CHECK(cache_device.is_cuda());
// Create data structures for the kernel.
// Create an array of pointers to the key and value and caches.
int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
int num_tokens = keys.size(1);
auto kv_dtype = keys.dtype();
torch::Tensor key_cache = key_caches[0];
torch::Tensor value_cache = value_caches[0];
int key_stride = keys.stride(1);
int value_stride = values.stride(1);
int num_heads = value_cache.size(1);
int head_size = value_cache.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor =
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
dim3 grid(num_layers, num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(slot_mapping));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_dtype, kv_cache_dtype,
CALL_READ_CACHE);
}
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_WRITE_CACHE_MULTI_LAYERS(KV_T, CACHE_T, KV_DTYPE) \
vllm::write_cache_multi_layers_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(keys.data_ptr()), \
reinterpret_cast<KV_T*>(values.data_ptr()), \
key_cache_ptrs_tensor.data_ptr<int64_t>(), \
value_cache_ptrs_tensor.data_ptr<int64_t>(), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, value_stride, \
num_heads, head_size, block_size, x, num_tokens);
void write_cache_multi_layers(
torch::Tensor& keys, // [num_layers, seq_len, num_heads, head_size]
torch::Tensor& values, // [num_layers, seq_len, num_heads, head_size]
std::vector<torch::Tensor> const& key_caches, // [num_blocks, num_heads, head_size/x, block_size, x]
std::vector<torch::Tensor> const& value_caches, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
torch::Device cache_device = key_caches[0].device();
TORCH_CHECK(cache_device.is_cuda());
// Create data structures for the kernel.
// Create an array of pointers to the key and value and caches.
int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
auto kv_dtype = keys.dtype();
int num_tokens = keys.size(1);
torch::Tensor key_cache = key_caches[0];
torch::Tensor value_cache = value_caches[0];
int key_stride = keys.stride(1);
int value_stride = values.stride(1);
int num_heads = value_cache.size(1);
int head_size = value_cache.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor =
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
dim3 grid(num_layers, num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(slot_mapping));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_dtype, kv_cache_dtype,
CALL_WRITE_CACHE_MULTI_LAYERS);
}
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ #define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \ vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
...@@ -1766,78 +1449,6 @@ void indexer_k_quant_and_cache( ...@@ -1766,78 +1449,6 @@ void indexer_k_quant_and_cache(
CALL_INDEXER_K_QUANT_AND_CACHE); CALL_INDEXER_K_QUANT_AND_CACHE);
} }
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_CACHE(KV_T, CACHE_T) \
vllm::indexer_k_cache_kernel<KV_T, CACHE_T> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(k.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), head_dim, \
cache_block_size, cache_stride);
void indexer_k_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& scale_fmt) {
int num_tokens = k.size(0);
int head_dim = k.size(1);
int cache_block_size = kv_cache.size(1);
int cache_stride = kv_cache.size(2);
bool use_ue8m0 = scale_fmt == "ue8m0";
TORCH_CHECK(k.device() == kv_cache.device(),
"k and kv_cache must be on the same device");
TORCH_CHECK(k.device() == slot_mapping.device(),
"k and slot_mapping must be on the same device");
constexpr int vec_size = 4;
dim3 grid(num_tokens, (head_dim + vec_size - 1) / vec_size);
dim3 block(32, vec_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
k.scalar_type(), "indexer_k_cache", ([&] {
using k_t = scalar_t;
auto kv_cache_type = kv_cache.scalar_type();
if (kv_cache_type == at::ScalarType::Float) {
vllm::indexer_k_cache_kernel<k_t, float>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
kv_cache.data_ptr<float>(),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
} else if (kv_cache_type == at::ScalarType::Half) {
vllm::indexer_k_cache_kernel<k_t, at::Half>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
kv_cache.data_ptr<at::Half>(),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
} else if (kv_cache_type == at::ScalarType::BFloat16) {
vllm::indexer_k_cache_kernel<k_t, at::BFloat16>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
kv_cache.data_ptr<at::BFloat16>(),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
} else {
TORCH_CHECK(false, "Unsupported kv_cache dtype: ", kv_cache.dtype());
}
}));
}
// Macro to dispatch the kernel based on the data amount. // Macro to dispatch the kernel based on the data amount.
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ #define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \ vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
......
#include <map>
#include <vector>
#include "cpu_types.hpp"
#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif
namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& mapping_pairs,
const int element_num_per_block,
const int layer_num) {
const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset =
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t* target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
}
}
}
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size;
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx >= 0) {
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx =
token_idx * value_stride + head_idx * head_size;
const scalar_t* src_key_head_ptr = key + src_key_head_idx;
const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
scalar_t* target_key_head_ptr = key_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
scalar_t* target_value_head_ptr = value_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
const int64_t target_offset =
src_key_idx * block_size + block_offset * x;
for (int i = 0; i < x; ++i) {
target_key_head_ptr[target_offset + i] =
src_key_head_ptr[src_key_idx + i];
}
}
for (int src_value_idx = 0; src_value_idx < head_size;
++src_value_idx) {
const int64_t target_offset =
src_value_idx * block_size + block_offset;
target_value_head_ptr[target_offset] =
src_value_head_ptr[src_value_idx];
}
}
}
}
}
}; // namespace
template <typename scalar_t>
void concat_and_cache_mla_cpu_impl(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int num_tokens, //
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size //
) {
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
continue;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src,
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
int size, int offset) {
for (int i = 0; i < size; i++) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
dst[dst_idx] = src[src_idx];
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
}
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
const int element_num_per_block = key_caches[0][0].numel();
DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
}
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
reshape_and_cache_cpu_impl<scalar_t>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
num_heads, head_size, block_size, x);
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
});
}
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
TORCH_CHECK(kv_cache_dtype != "fp8");
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
VLLM_DISPATCH_FLOATING_TYPES(
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
concat_and_cache_mla_cpu_impl<scalar_t>(
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
kv_lora_rank, pe_dim, block_size);
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
});
}
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}
void read_cache(
std::vector<torch::Tensor> const& keys,
std::vector<torch::Tensor> const& values,
std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype) {
TORCH_CHECK(false, "read_cache is unsupported on CPU.")
}
void write_cache_multi_layers(
std::vector<torch::Tensor> const& keys,
std::vector<torch::Tensor> const& values,
std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype) {
TORCH_CHECK(false, "write_cache_multi_layers is unsupported on CPU.")
}
...@@ -360,14 +360,13 @@ void onednn_scaled_mm( ...@@ -360,14 +360,13 @@ void onednn_scaled_mm(
const std::optional<torch::Tensor>& azp, // [M] or [1] const std::optional<torch::Tensor>& azp, // [M] or [1]
const std::optional<torch::Tensor>& azp_adj, // [M] or [1] const std::optional<torch::Tensor>& azp_adj, // [M] or [1]
const std::optional<torch::Tensor>& bias, // [N] const std::optional<torch::Tensor>& bias, // [N]
const torch::Tensor& handler_tensor) { int64_t handler) {
CPU_KERNEL_GUARD_IN(onednn_scaled_mm) CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
TORCH_CHECK(a.dim() == 2); TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(a.is_contiguous()); TORCH_CHECK(a.is_contiguous());
TORCH_CHECK(c.is_contiguous()); TORCH_CHECK(c.is_contiguous());
W8A8MatMulPrimitiveHandler* ptr = W8A8MatMulPrimitiveHandler* ptr =
reinterpret_cast<W8A8MatMulPrimitiveHandler*>( reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler);
handler_tensor.item<int64_t>());
const int32_t* azp_ptr = nullptr; const int32_t* azp_ptr = nullptr;
if (azp.has_value()) { if (azp.has_value()) {
azp_ptr = azp->data_ptr<int32_t>(); azp_ptr = azp->data_ptr<int32_t>();
...@@ -520,14 +519,13 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b, ...@@ -520,14 +519,13 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
void onednn_mm(torch::Tensor& c, // [M, OC], row-major void onednn_mm(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a, // [M, IC], row-major const torch::Tensor& a, // [M, IC], row-major
const std::optional<torch::Tensor>& bias, const std::optional<torch::Tensor>& bias, int64_t handler) {
const torch::Tensor& handler_tensor) {
CPU_KERNEL_GUARD_IN(onednn_mm) CPU_KERNEL_GUARD_IN(onednn_mm)
TORCH_CHECK(a.dim() == 2); TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(a.stride(-1) == 1); TORCH_CHECK(a.stride(-1) == 1);
TORCH_CHECK(c.stride(-1) == 1); TORCH_CHECK(c.stride(-1) == 1);
MatMulPrimitiveHandler* ptr = MatMulPrimitiveHandler* ptr =
reinterpret_cast<MatMulPrimitiveHandler*>(handler_tensor.item<int64_t>()); reinterpret_cast<MatMulPrimitiveHandler*>(handler);
// ACL matmuls expect contiguous source tensors // ACL matmuls expect contiguous source tensors
#ifdef VLLM_USE_ACL #ifdef VLLM_USE_ACL
...@@ -569,4 +567,4 @@ void onednn_mm(torch::Tensor& c, // [M, OC], row-major ...@@ -569,4 +567,4 @@ void onednn_mm(torch::Tensor& c, // [M, OC], row-major
ptr->execute(exec_args); ptr->execute(exec_args);
}); });
} }
\ No newline at end of file
...@@ -19,14 +19,13 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, ...@@ -19,14 +19,13 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& azp, const std::optional<torch::Tensor>& azp,
const std::optional<torch::Tensor>& azp_adj, const std::optional<torch::Tensor>& azp_adj,
const std::optional<torch::Tensor>& bias, const std::optional<torch::Tensor>& bias,
const torch::Tensor& handler_tensor); int64_t handler);
int64_t create_onednn_mm_handler(const torch::Tensor& b, int64_t create_onednn_mm_handler(const torch::Tensor& b,
int64_t primitive_cache_size); int64_t primitive_cache_size);
void onednn_mm(torch::Tensor& c, const torch::Tensor& a, void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& bias, const std::optional<torch::Tensor>& bias, int64_t handler);
const torch::Tensor& handler_tensor);
bool is_onednn_acl_supported(); bool is_onednn_acl_supported();
...@@ -197,7 +196,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -197,7 +196,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// oneDNN GEMM // oneDNN GEMM
ops.def( ops.def(
"onednn_mm(Tensor! c, Tensor a, Tensor? bias, " "onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
"Tensor handler_tensor) -> ()"); "int handler) -> ()");
ops.impl("onednn_mm", torch::kCPU, &onednn_mm); ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
// Check if oneDNN was built with ACL backend // Check if oneDNN was built with ACL backend
...@@ -213,7 +212,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -213,7 +212,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// oneDNN scaled_mm for W8A8 with static per-tensor activation quantization // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
ops.def( ops.def(
"onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, " "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
"Tensor? azp_adj, Tensor? bias, Tensor handler_tensor) -> ()"); "Tensor? azp_adj, Tensor? bias, int handler) -> ()");
ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm); ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);
// Compute int8 quantized tensor for given scaling factor. // Compute int8 quantized tensor for given scaling factor.
...@@ -296,9 +295,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -296,9 +295,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
&cpu_attention_with_kv_cache); &cpu_attention_with_kv_cache);
// placeholders // placeholders
// ops.def("static_scaled_fp8_quant() -> ()", placeholder_op); ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
// ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op); ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
// ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op); ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
// WNA16 // WNA16
#if defined(__AVX512F__) #if defined(__AVX512F__)
...@@ -336,4 +335,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) { ...@@ -336,4 +335,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache); cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
} }
REGISTER_EXTENSION(TORCH_EXTENSION_NAME) REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file
...@@ -755,22 +755,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -755,22 +755,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash); &reshape_and_cache_flash);
// read key and value form kv cache
cache_ops.def(
"read_cache(Tensor keys, Tensor values,"
" Tensor[]! key_caches, Tensor[]! value_caches,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()");
cache_ops.impl("read_cache", torch::kCUDA, &read_cache);
// write multi-layers key and value to kv cache
cache_ops.def(
"write_cache_multi_layers(Tensor keys, Tensor values,"
" Tensor[]! key_caches, Tensor[]! value_caches,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()");
cache_ops.impl("write_cache_multi_layers", torch::kCUDA, &write_cache_multi_layers);
// Concat kv_c and k_pe and cache them. // Concat kv_c and k_pe and cache them.
cache_ops.def( cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe," "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
...@@ -833,14 +817,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -833,14 +817,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache); &indexer_k_quant_and_cache);
cache_ops.def(
"indexer_k_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_cache", torch::kCUDA,
&indexer_k_cache);
cache_ops.def( cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! " "cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"); "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
......
...@@ -47,10 +47,6 @@ You can tune the performance by adjusting `max_num_batched_tokens`: ...@@ -47,10 +47,6 @@ You can tune the performance by adjusting `max_num_batched_tokens`:
- For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs. - For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs.
- If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes). - If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes).
!!! warning
When chunked prefill is disabled, `max_num_batched_tokens` must be greater than `max_model_len`.
In that case, if `max_num_batched_tokens < max_model_len`, vLLM may crash at server start‑up.
```python ```python
from vllm import LLM from vllm import LLM
...@@ -289,4 +285,4 @@ Based on the configuration, the content of the multi-modal caches on `P0` and `P ...@@ -289,4 +285,4 @@ Based on the configuration, the content of the multi-modal caches on `P0` and `P
| N/A | Disabled | N/A | N/A | N/A | `0` | | N/A | Disabled | N/A | N/A | N/A | `0` |
K: Stores the hashes of multi-modal items K: Stores the hashes of multi-modal items
V: Stores the processed tensor data of multi-modal items V: Stores the processed tensor data of multi-modal items
\ No newline at end of file
...@@ -71,7 +71,7 @@ class MyModel(nn.Module): ...@@ -71,7 +71,7 @@ class MyModel(nn.Module):
```python ```python
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -145,4 +145,4 @@ Please see [`LinearAttentionMetadata`](../../../vllm/v1/attention/backends/linea ...@@ -145,4 +145,4 @@ Please see [`LinearAttentionMetadata`](../../../vllm/v1/attention/backends/linea
It is also worth noting that we should update `MAMBA_TYPE_TO_BACKEND_MAP` and `MambaAttentionBackendEnum` in [`registry.py`](../../../vllm/v1/attention/backends/registry.py) when adding a new mamba backend. It is also worth noting that we should update `MAMBA_TYPE_TO_BACKEND_MAP` and `MambaAttentionBackendEnum` in [`registry.py`](../../../vllm/v1/attention/backends/registry.py) when adding a new mamba backend.
Finally, if one wants to support torch compile and CUDA graphs, it necessary to wrap the call to the mamba-like layer inside a custom op and register it. Finally, if one wants to support torch compile and CUDA graphs, it necessary to wrap the call to the mamba-like layer inside a custom op and register it.
Please see the calls to `direct_register_custom_op` in [vllm/model_executor/models/minimax_text_01.py](../../../vllm/model_executor/models/minimax_text_01.py) or [vllm/model_executor/layers/mamba/short_conv.py](../../../vllm/model_executor/layers/mamba/short_conv.py) for examples of this. Please see the calls to `direct_register_custom_op` in [vllm/model_executor/models/minimax_text_01.py](../../../vllm/model_executor/models/minimax_text_01.py) or [vllm/model_executor/layers/mamba/short_conv.py](../../../vllm/model_executor/layers/mamba/short_conv.py) for examples of this.
The new custom op should then be added to the list `_attention_ops` in [vllm/config/compilation.py](../../../vllm/config/compilation.py) to ensure that piecewise CUDA graphs works as intended. The new custom op should then be added to the list `_attention_ops` in [vllm/config/compilation.py](../../../vllm/config/compilation.py) to ensure that piecewise CUDA graphs works as intended.
\ No newline at end of file
...@@ -43,73 +43,28 @@ Further update the model as follows: ...@@ -43,73 +43,28 @@ Further update the model as follows:
) )
``` ```
- Remove the embedding part from the [forward][torch.nn.Module.forward] method: - Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
- Move the multi-modal embedding to [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal].
- The text embedding and embedding merge are handled automatically by a default implementation of [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids]. It does not need to be overridden in most cases.
```diff
def forward(
self,
input_ids: torch.Tensor | None,
- pixel_values: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- )
- special_image_mask = self.get_placeholder_mask(
- input_ids,
- inputs_embeds=inputs_embeds,
- image_features=image_features,
- )
- inputs_embeds = inputs_embeds.masked_scatter(
- special_image_mask,
- image_features,
- )
hidden_states = self.language_model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
...
+ def embed_multimodal(
+ self,
+ pixel_values: torch.Tensor,
+ ) -> MultiModalEmbeddings | None:
+ return self.get_image_features(
+ pixel_values=pixel_values,
+ )
```
Below we provide a boilerplate of a typical implementation pattern of [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal], but feel free to adjust it to your own needs.
```python ??? code
def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
image_features = self.vision_encoder(image_input)
return self.multi_modal_projector(image_features)
def embed_multimodal( ```python
self, def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
**kwargs: object, image_features = self.vision_encoder(image_input)
) -> MultiModalEmbeddings | None: return self.multi_modal_projector(image_features)
# Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs) def embed_multimodal(
if image_input is None: self,
return None **kwargs: object,
) -> MultiModalEmbeddings | None:
# Run multimodal inputs through encoder and projector # Validate the multimodal input keyword arguments
vision_embeddings = self._process_image_input(image_input) image_input = self._parse_and_validate_image_input(**kwargs)
return vision_embeddings if image_input is None:
``` return None
# Run multimodal inputs through encoder and projector
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
```
!!! important !!! important
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request. The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request.
...@@ -884,4 +839,4 @@ Examples: ...@@ -884,4 +839,4 @@ Examples:
- DeepSeek-VL2: [vllm/model_executor/models/deepseek_vl2.py](../../../vllm/model_executor/models/deepseek_vl2.py) - DeepSeek-VL2: [vllm/model_executor/models/deepseek_vl2.py](../../../vllm/model_executor/models/deepseek_vl2.py)
- InternVL: [vllm/model_executor/models/internvl.py](../../../vllm/model_executor/models/internvl.py) - InternVL: [vllm/model_executor/models/internvl.py](../../../vllm/model_executor/models/internvl.py)
- Qwen-VL: [vllm/model_executor/models/qwen_vl.py](../../../vllm/model_executor/models/qwen_vl.py) - Qwen-VL: [vllm/model_executor/models/qwen_vl.py](../../../vllm/model_executor/models/qwen_vl.py)
\ No newline at end of file
...@@ -10,7 +10,7 @@ receives a request for a LoRA adapter that hasn't been loaded yet, the resolver ...@@ -10,7 +10,7 @@ receives a request for a LoRA adapter that hasn't been loaded yet, the resolver
to locate and load the adapter from their configured storage locations. This enables: to locate and load the adapter from their configured storage locations. This enables:
- **Dynamic LoRA Loading**: Load adapters on-demand without server restarts - **Dynamic LoRA Loading**: Load adapters on-demand without server restarts
- **Multiple Storage Backends**: Support for filesystem, S3, and custom backends. The built-in `lora_filesystem_resolver` requires a local storage path, while the built-in `hf_hub_resolver` will pull LoRA adapters from Huggingface Hub and proceed in an identical manner. In general, custom resolvers can be implemented to fetch from any source. - **Multiple Storage Backends**: Support for filesystem, S3, and custom backends. The built-in `lora_filesystem_resolver` requires a local storage path, but custom resolvers can be implemented to fetch from any source.
- **Automatic Discovery**: Seamless integration with existing LoRA workflows - **Automatic Discovery**: Seamless integration with existing LoRA workflows
- **Scalable Deployment**: Centralized adapter management across multiple vLLM instances - **Scalable Deployment**: Centralized adapter management across multiple vLLM instances
...@@ -217,4 +217,4 @@ To implement your own resolver plugin: ...@@ -217,4 +217,4 @@ To implement your own resolver plugin:
config = json.load(f) config = json.load(f)
print('Config valid:', config) print('Config valid:', config)
" "
``` ```
\ No newline at end of file
...@@ -36,7 +36,8 @@ th { ...@@ -36,7 +36,8 @@ th {
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] | | pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] | | deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] | | flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] |
| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] | | MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] | | BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
...@@ -111,4 +112,4 @@ The following table shows "families" of modular kernels that are intended to wor ...@@ -111,4 +112,4 @@ The following table shows "families" of modular kernels that are intended to wor
|---------|-----------------------------------------|----------------------------------------------| |---------|-----------------------------------------|----------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` | | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` | | deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
\ No newline at end of file
...@@ -159,12 +159,10 @@ Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adap ...@@ -159,12 +159,10 @@ Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adap
You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds. You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.
You can either install existing plugins or implement your own. By default, vLLM comes with a [resolver plugin to load LoRA adapters from a local directory, as well as a resolver plugin to load LoRA adapters from repositories on Hugging Face Hub](https://github.com/vllm-project/vllm/tree/main/vllm/plugins/lora_resolvers) You can either install existing plugins or implement your own. By default, vLLM comes with a [resolver plugin to load LoRA adapters from a local directory.](https://github.com/vllm-project/vllm/tree/main/vllm/plugins/lora_resolvers)
To enable either of these resolvers, you must `set VLLM_ALLOW_RUNTIME_LORA_UPDATING` to True. To enable this resolver, set `VLLM_ALLOW_RUNTIME_LORA_UPDATING` to True, set `VLLM_PLUGINS` to include `lora_filesystem_resolver`, and then set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request using a LoRA adapter `foobar`,
it will first look in the local directory for a directory `foobar`, and attempt to load the contents of that directory as a LoRA adapter. If successful, the request will complete as normal and
- To leverage a local directory, set `VLLM_PLUGINS` to include `lora_filesystem_resolver` and set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request using a LoRA adapter `foobar`, that adapter will then be available for normal use on the server.
it will first look in the local directory for a directory `foobar`, and attempt to load the contents of that directory as a LoRA adapter. If successful, the request will complete as normal and that adapter will then be available for normal use on the server.
- To leverage repositories on Hugging Face Hub, set `VLLM_PLUGINS` to include `lora_hf_hub_resolver` and set `VLLM_LORA_RESOLVER_HF_REPO_LIST` to a comma separated list of repository IDs on Hugging Face Hub. When vLLM receives a request for the LoRA adapter `my/repo/subpath`, it will download the adapter at the `subpath` of `my/repo` if it exists and contains an `adapter_config.json`, then build a request to the cached dir for the adapter, similar to the `lora_filesystem_resolver`. Please note that enabling remote downloads is insecure and not intended for use in production environments.
Alternatively, follow these example steps to implement your own plugin: Alternatively, follow these example steps to implement your own plugin:
...@@ -387,4 +385,4 @@ vllm serve model --enable-lora --max-lora-rank 64 ...@@ -387,4 +385,4 @@ vllm serve model --enable-lora --max-lora-rank 64
# Bad: unnecessarily high, wastes memory # Bad: unnecessarily high, wastes memory
vllm serve model --enable-lora --max-lora-rank 256 vllm serve model --enable-lora --max-lora-rank 256
``` ```
\ No newline at end of file
...@@ -675,7 +675,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -675,7 +675,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | | `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + I<sup>E+</sup> | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ | | `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + I<sup>E+</sup> | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
...@@ -849,4 +848,4 @@ We have the following levels of testing for models: ...@@ -849,4 +848,4 @@ We have the following levels of testing for models:
1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. 1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test.
2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test.
3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](../../tests) and [examples](../../examples) for the models that have passed this test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](../../tests) and [examples](../../examples) for the models that have passed this test.
4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. 4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
\ No newline at end of file
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
This example shows how to use vLLM for running offline inference This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen3-Omni (thinker only). with the correct prompt format on Qwen2.5-Omni (thinker only).
""" """
from typing import NamedTuple from typing import NamedTuple
...@@ -112,51 +112,23 @@ def get_multi_audios_query() -> QueryResult: ...@@ -112,51 +112,23 @@ def get_multi_audios_query() -> QueryResult:
) )
def get_multi_images_query() -> QueryResult:
question = "What are the differences between these two images?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
"<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"image": [
convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"),
],
},
},
limit_mm_per_prompt={
"image": 2,
},
)
query_map = { query_map = {
"mixed_modalities": get_mixed_modalities_query, "mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query, "use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query, "multi_audios": get_multi_audios_query,
"multi_images": get_multi_images_query,
} }
def main(args): def main(args):
model_name = args.model model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
query_result = query_map[args.query_type]() query_result = query_map[args.query_type]()
llm = LLM( llm = LLM(
model=model_name, model=model_name,
max_model_len=args.max_model_len, max_model_len=12800,
max_num_seqs=5, max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt, limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed, seed=args.seed,
tensor_parallel_size=args.tensor_parallel_size,
gpu_memory_utilization=args.gpu_memory_utilization,
) )
# We set temperature to 0.2 so that outputs can be different # We set temperature to 0.2 so that outputs can be different
...@@ -189,35 +161,10 @@ def parse_args(): ...@@ -189,35 +161,10 @@ def parse_args():
default=0, default=0,
help="Set the seed when initializing `vllm.LLM`.", help="Set the seed when initializing `vllm.LLM`.",
) )
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
help="Model name or path.",
)
parser.add_argument(
"--tensor-parallel-size",
"-tp",
type=int,
default=1,
help="Tensor parallel size for distributed inference.",
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.9,
help="GPU memory utilization (0.0 to 1.0).",
)
parser.add_argument(
"--max-model-len",
type=int,
default=12800,
help="Maximum model context length.",
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
main(args) main(args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment